diff options
author | Andrew Morgan <andrew@amorgan.xyz> | 2021-05-04 11:49:14 +0100 |
---|---|---|
committer | Andrew Morgan <andrew@amorgan.xyz> | 2021-05-04 11:49:14 +0100 |
commit | b8d85c9b7797ce0992323ace55fba9172776a8ab (patch) | |
tree | 5b86f07223f90ee1102aa8b6ba68102e96997054 | |
parent | Changelog (diff) | |
parent | Merge tag 'v1.33.0rc2' into develop (diff) | |
download | synapse-b8d85c9b7797ce0992323ace55fba9172776a8ab.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into anoa/module_api_full_presence_fix
163 files changed, 3219 insertions, 1908 deletions
diff --git a/CHANGES.md b/CHANGES.md index 41908f84be..bdeb614b9e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,11 +1,128 @@ -Synapse 1.32.0rc1 (2021-04-13) +Synapse 1.33.0rc2 (2021-04-29) +============================== + +Bugfixes +-------- + +- Fix tight loop when handling presence replication when using workers. Introduced in v1.33.0rc1. ([\#9900](https://github.com/matrix-org/synapse/issues/9900)) + + +Synapse 1.33.0rc1 (2021-04-28) ============================== +Features +-------- + +- Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. ([\#9800](https://github.com/matrix-org/synapse/issues/9800), [\#9814](https://github.com/matrix-org/synapse/issues/9814)) +- Add experimental support for handling presence on a worker. ([\#9819](https://github.com/matrix-org/synapse/issues/9819), [\#9820](https://github.com/matrix-org/synapse/issues/9820), [\#9828](https://github.com/matrix-org/synapse/issues/9828), [\#9850](https://github.com/matrix-org/synapse/issues/9850)) +- Return a new template when an user attempts to renew their account multiple times with the same token, stating that their account is set to expire. This replaces the invalid token template that would previously be shown in this case. This change concerns the optional account validity feature. ([\#9832](https://github.com/matrix-org/synapse/issues/9832)) + + +Bugfixes +-------- + +- Fixes the OIDC SSO flow when using a `public_baseurl` value including a non-root URL path. ([\#9726](https://github.com/matrix-org/synapse/issues/9726)) +- Fix thumbnail generation for some sites with non-standard content types. Contributed by @rkfg. ([\#9788](https://github.com/matrix-org/synapse/issues/9788)) +- Add some sanity checks to identity server passed to 3PID bind/unbind endpoints. ([\#9802](https://github.com/matrix-org/synapse/issues/9802)) +- Limit the size of HTTP responses read over federation. ([\#9833](https://github.com/matrix-org/synapse/issues/9833)) +- Fix a bug which could cause Synapse to get stuck in a loop of resyncing device lists. ([\#9867](https://github.com/matrix-org/synapse/issues/9867)) +- Fix a long-standing bug where errors from federation did not propagate to the client. ([\#9868](https://github.com/matrix-org/synapse/issues/9868)) + + +Improved Documentation +---------------------- + +- Add a note to the docker docs mentioning that we mirror upstream's supported Docker platforms. ([\#9801](https://github.com/matrix-org/synapse/issues/9801)) + + +Internal Changes +---------------- + +- Add a dockerfile for running Synapse in worker-mode under Complement. ([\#9162](https://github.com/matrix-org/synapse/issues/9162)) +- Apply `pyupgrade` across the codebase. ([\#9786](https://github.com/matrix-org/synapse/issues/9786)) +- Move some replication processing out of `generic_worker`. ([\#9796](https://github.com/matrix-org/synapse/issues/9796)) +- Replace `HomeServer.get_config()` with inline references. ([\#9815](https://github.com/matrix-org/synapse/issues/9815)) +- Rename some handlers and config modules to not duplicate the top-level module. ([\#9816](https://github.com/matrix-org/synapse/issues/9816)) +- Fix a long-standing bug which caused `max_upload_size` to not be correctly enforced. ([\#9817](https://github.com/matrix-org/synapse/issues/9817)) +- Reduce CPU usage of the user directory by reusing existing calculated room membership. ([\#9821](https://github.com/matrix-org/synapse/issues/9821)) +- Small speed up for joining large remote rooms. ([\#9825](https://github.com/matrix-org/synapse/issues/9825)) +- Introduce flake8-bugbear to the test suite and fix some of its lint violations. ([\#9838](https://github.com/matrix-org/synapse/issues/9838)) +- Only store the raw data in the in-memory caches, rather than objects that include references to e.g. the data stores. ([\#9845](https://github.com/matrix-org/synapse/issues/9845)) +- Limit length of accepted email addresses. ([\#9855](https://github.com/matrix-org/synapse/issues/9855)) +- Remove redundant `synapse.types.Collection` type definition. ([\#9856](https://github.com/matrix-org/synapse/issues/9856)) +- Handle recently added rate limits correctly when using `--no-rate-limit` with the demo scripts. ([\#9858](https://github.com/matrix-org/synapse/issues/9858)) +- Disable invite rate-limiting by default when running the unit tests. ([\#9871](https://github.com/matrix-org/synapse/issues/9871)) +- Pass a reactor into `SynapseSite` to make testing easier. ([\#9874](https://github.com/matrix-org/synapse/issues/9874)) +- Make `DomainSpecificString` an `attrs` class. ([\#9875](https://github.com/matrix-org/synapse/issues/9875)) +- Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules. ([\#9876](https://github.com/matrix-org/synapse/issues/9876)) +- Remove redundant `_PushHTTPChannel` test class. ([\#9878](https://github.com/matrix-org/synapse/issues/9878)) +- Remove backwards-compatibility code for Python versions < 3.6. ([\#9879](https://github.com/matrix-org/synapse/issues/9879)) +- Small performance improvement around handling new local presence updates. ([\#9887](https://github.com/matrix-org/synapse/issues/9887)) + + +Synapse 1.32.2 (2021-04-22) +=========================== + +This release includes a fix for a regression introduced in 1.32.0. + +Bugfixes +-------- + +- Fix a regression in Synapse 1.32.0 and 1.32.1 which caused `LoggingContext` errors in plugins. ([\#9857](https://github.com/matrix-org/synapse/issues/9857)) + + +Synapse 1.32.1 (2021-04-21) +=========================== + +This release fixes [a regression](https://github.com/matrix-org/synapse/issues/9853) +in Synapse 1.32.0 that caused connected Prometheus instances to become unstable. + +However, as this release is still subject to the `LoggingContext` change in 1.32.0, +it is recommended to remain on or downgrade to 1.31.0. + +Bugfixes +-------- + +- Fix a regression in Synapse 1.32.0 which caused Synapse to report large numbers of Prometheus time series, potentially overwhelming Prometheus instances. ([\#9854](https://github.com/matrix-org/synapse/issues/9854)) + + +Synapse 1.32.0 (2021-04-20) +=========================== + +**Note:** This release introduces [a regression](https://github.com/matrix-org/synapse/issues/9853) +that can overwhelm connected Prometheus instances. This issue was not present in +1.32.0rc1. If affected, it is recommended to downgrade to 1.31.0 in the meantime, and +follow [these instructions](https://github.com/matrix-org/synapse/pull/9854#issuecomment-823472183) +to clean up any excess writeahead logs. + +**Note:** This release also mistakenly included a change that may affected Synapse +modules that import `synapse.logging.context.LoggingContext`, such as +[synapse-s3-storage-provider](https://github.com/matrix-org/synapse-s3-storage-provider). +This will be fixed in a later Synapse version. + **Note:** This release requires Python 3.6+ and Postgres 9.6+ or SQLite 3.22+. This release removes the deprecated `GET /_synapse/admin/v1/users/<user_id>` admin API. Please use the [v2 API](https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/user_admin_api.rst#query-user-account) instead, which has improved capabilities. -This release requires Application Services to use type `m.login.application_services` when registering users via the `/_matrix/client/r0/register` endpoint to comply with the spec. Please ensure your Application Services are up to date. +This release requires Application Services to use type `m.login.application_service` when registering users via the `/_matrix/client/r0/register` endpoint to comply with the spec. Please ensure your Application Services are up to date. + +If you are using the `packages.matrix.org` Debian repository for Synapse packages, +note that we have recently updated the expiry date on the gpg signing key. If you see an +error similar to `The following signatures were invalid: EXPKEYSIG F473DD4473365DE1`, you +will need to get a fresh copy of the keys. You can do so with: + +```sh +sudo wget -O /usr/share/keyrings/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg +``` + +Bugfixes +-------- + +- Fix the log lines of nested logging contexts. Broke in 1.32.0rc1. ([\#9829](https://github.com/matrix-org/synapse/issues/9829)) + + +Synapse 1.32.0rc1 (2021-04-13) +============================== Features -------- diff --git a/UPGRADE.rst b/UPGRADE.rst index 665821d4ef..e921e0c08a 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -85,9 +85,52 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.33.0 +==================== + +Account Validity HTML templates can now display a user's expiration date +------------------------------------------------------------------------ + +This may affect you if you have enabled the account validity feature, and have made use of a +custom HTML template specified by the ``account_validity.template_dir`` or ``account_validity.account_renewed_html_path`` +Synapse config options. + +The template can now accept an ``expiration_ts`` variable, which represents the unix timestamp in milliseconds for the +future date of which their account has been renewed until. See the +`default template <https://github.com/matrix-org/synapse/blob/release-v1.33.0/synapse/res/templates/account_renewed.html>`_ +for an example of usage. + +ALso note that a new HTML template, ``account_previously_renewed.html``, has been added. This is is shown to users +when they attempt to renew their account with a valid renewal token that has already been used before. The default +template contents can been found +`here <https://github.com/matrix-org/synapse/blob/release-v1.33.0/synapse/res/templates/account_previously_renewed.html>`_, +and can also accept an ``expiration_ts`` variable. This template replaces the error message users would previously see +upon attempting to use a valid renewal token more than once. + + Upgrading to v1.32.0 ==================== +Regression causing connected Prometheus instances to become overwhelmed +----------------------------------------------------------------------- + +This release introduces `a regression <https://github.com/matrix-org/synapse/issues/9853>`_ +that can overwhelm connected Prometheus instances. This issue is not present in +Synapse v1.32.0rc1. + +If you have been affected, please downgrade to 1.31.0. You then may need to +remove excess writeahead logs in order for Prometheus to recover. Instructions +for doing so are provided +`here <https://github.com/matrix-org/synapse/pull/9854#issuecomment-823472183>`_. + +Dropping support for old Python, Postgres and SQLite versions +------------------------------------------------------------- + +In line with our `deprecation policy <https://github.com/matrix-org/synapse/blob/release-v1.32.0/docs/deprecation_policy.md>`_, +we've dropped support for Python 3.5 and PostgreSQL 9.5, as they are no longer supported upstream. + +This release of Synapse requires Python 3.6+ and PostgresSQL 9.6+ or SQLite 3.22+. + Removal of old List Accounts Admin API -------------------------------------- @@ -98,6 +141,16 @@ has been available since Synapse 1.7.0 (2019-12-13), and is accessible under ``G The deprecation of the old endpoint was announced with Synapse 1.28.0 (released on 2021-02-25). +Application Services must use type ``m.login.application_service`` when registering users +----------------------------------------------------------------------------------------- + +In compliance with the +`Application Service spec <https://matrix.org/docs/spec/application_service/r0.1.2#server-admin-style-permissions>`_, +Application Services are now required to use the ``m.login.application_service`` type when registering users via the +``/_matrix/client/r0/register`` endpoint. This behaviour was deprecated in Synapse v1.30.0. + +Please ensure your Application Services are up to date. + Upgrading to v1.29.0 ==================== diff --git a/changelog.d/9162.misc b/changelog.d/9162.misc deleted file mode 100644 index 1083da8a7a..0000000000 --- a/changelog.d/9162.misc +++ /dev/null @@ -1 +0,0 @@ -Add a dockerfile for running Synapse in worker-mode under Complement. \ No newline at end of file diff --git a/changelog.d/9702.misc b/changelog.d/9702.misc deleted file mode 100644 index c6e63450a9..0000000000 --- a/changelog.d/9702.misc +++ /dev/null @@ -1 +0,0 @@ -Speed up federation transmission by using fewer database calls. Contributed by @ShadowJonathan. diff --git a/changelog.d/9786.misc b/changelog.d/9786.misc deleted file mode 100644 index cf265db749..0000000000 --- a/changelog.d/9786.misc +++ /dev/null @@ -1 +0,0 @@ -Apply `pyupgrade` across the codebase. \ No newline at end of file diff --git a/changelog.d/9788.bugfix b/changelog.d/9788.bugfix deleted file mode 100644 index edb58fbd5b..0000000000 --- a/changelog.d/9788.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix thumbnail generation for some sites with non-standard content types. Contributed by @rkfg. diff --git a/changelog.d/9796.misc b/changelog.d/9796.misc deleted file mode 100644 index 59bb1813c3..0000000000 --- a/changelog.d/9796.misc +++ /dev/null @@ -1 +0,0 @@ -Move some replication processing out of `generic_worker`. diff --git a/changelog.d/9800.feature b/changelog.d/9800.feature deleted file mode 100644 index 9404ad2fc0..0000000000 --- a/changelog.d/9800.feature +++ /dev/null @@ -1 +0,0 @@ -Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. diff --git a/changelog.d/9801.doc b/changelog.d/9801.doc deleted file mode 100644 index 8b8b9d01d4..0000000000 --- a/changelog.d/9801.doc +++ /dev/null @@ -1 +0,0 @@ -Add a note to the docker docs mentioning that we mirror upstream's supported Docker platforms. diff --git a/changelog.d/9815.misc b/changelog.d/9815.misc deleted file mode 100644 index e33d012d3d..0000000000 --- a/changelog.d/9815.misc +++ /dev/null @@ -1 +0,0 @@ -Replace `HomeServer.get_config()` with inline references. diff --git a/changelog.d/9885.misc b/changelog.d/9885.misc new file mode 100644 index 0000000000..492fccea46 --- /dev/null +++ b/changelog.d/9885.misc @@ -0,0 +1 @@ +Add type hints to presence handler. diff --git a/changelog.d/9886.misc b/changelog.d/9886.misc new file mode 100644 index 0000000000..8ff869e659 --- /dev/null +++ b/changelog.d/9886.misc @@ -0,0 +1 @@ +Reduce memory usage of the LRU caches. diff --git a/changelog.d/9889.feature b/changelog.d/9889.feature new file mode 100644 index 0000000000..74d46f222e --- /dev/null +++ b/changelog.d/9889.feature @@ -0,0 +1 @@ +Add support for `DELETE /_synapse/admin/v1/rooms/<room_id>`. \ No newline at end of file diff --git a/changelog.d/9889.removal b/changelog.d/9889.removal new file mode 100644 index 0000000000..398b9e129b --- /dev/null +++ b/changelog.d/9889.removal @@ -0,0 +1 @@ +Mark as deprecated `POST /_synapse/admin/v1/rooms/<room_id>/delete`. \ No newline at end of file diff --git a/changelog.d/9895.bugfix b/changelog.d/9895.bugfix new file mode 100644 index 0000000000..1053f975bf --- /dev/null +++ b/changelog.d/9895.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.32.0 where the associated connection was improperly logged for SQL logging statements. diff --git a/changelog.d/9896.bugfix b/changelog.d/9896.bugfix new file mode 100644 index 0000000000..07a8e87f9f --- /dev/null +++ b/changelog.d/9896.bugfix @@ -0,0 +1 @@ +Correct the type hint for the `user_may_create_room_alias` method of spam checkers. It is provided a `RoomAlias`, not a `str`. diff --git a/changelog.d/9896.misc b/changelog.d/9896.misc new file mode 100644 index 0000000000..e41c7d1f02 --- /dev/null +++ b/changelog.d/9896.misc @@ -0,0 +1 @@ +Add type hints to the `synapse.handlers` module. diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 5dd172052b..31b8a68225 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -224,16 +224,14 @@ class HomeServer(ReplicationHandler): destinations = yield self.get_servers_for_context(room_name) try: - yield self.replication_layer.send_pdus( - [ - Pdu.create_new( - context=room_name, - pdu_type="sy.room.message", - content={"sender": sender, "body": body}, - origin=self.server_name, - destinations=destinations, - ) - ] + yield self.replication_layer.send_pdu( + Pdu.create_new( + context=room_name, + pdu_type="sy.room.message", + content={"sender": sender, "body": body}, + origin=self.server_name, + destinations=destinations, + ) ) except Exception as e: logger.exception(e) @@ -255,7 +253,7 @@ class HomeServer(ReplicationHandler): origin=self.server_name, destinations=destinations, ) - yield self.replication_layer.send_pdus([pdu]) + yield self.replication_layer.send_pdu(pdu) except Exception as e: logger.exception(e) @@ -267,18 +265,16 @@ class HomeServer(ReplicationHandler): destinations = yield self.get_servers_for_context(room_name) try: - yield self.replication_layer.send_pdus( - [ - Pdu.create_new( - context=room_name, - is_state=True, - pdu_type="sy.room.member", - state_key=invitee, - content={"membership": "invite"}, - origin=self.server_name, - destinations=destinations, - ) - ] + yield self.replication_layer.send_pdu( + Pdu.create_new( + context=room_name, + is_state=True, + pdu_type="sy.room.member", + state_key=invitee, + content={"membership": "invite"}, + origin=self.server_name, + destinations=destinations, + ) ) except Exception as e: logger.exception(e) diff --git a/debian/changelog b/debian/changelog index 5d526316fc..fd33bfda5c 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,8 +1,24 @@ -matrix-synapse-py3 (1.31.0+nmu1) UNRELEASED; urgency=medium +matrix-synapse-py3 (1.32.2) stable; urgency=medium + * New synapse release 1.32.2. + + -- Synapse Packaging team <packages@matrix.org> Wed, 22 Apr 2021 12:43:52 +0100 + +matrix-synapse-py3 (1.32.1) stable; urgency=medium + + * New synapse release 1.32.1. + + -- Synapse Packaging team <packages@matrix.org> Wed, 21 Apr 2021 14:00:55 +0100 + +matrix-synapse-py3 (1.32.0) stable; urgency=medium + + [ Dan Callahan ] * Skip tests when DEB_BUILD_OPTIONS contains "nocheck". - -- Dan Callahan <danc@element.io> Mon, 12 Apr 2021 13:07:36 +0000 + [ Synapse Packaging team ] + * New synapse release 1.32.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 20 Apr 2021 14:28:39 +0100 matrix-synapse-py3 (1.31.0) stable; urgency=medium diff --git a/demo/start.sh b/demo/start.sh index 621a5698b8..bc4854091b 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -96,18 +96,48 @@ for port in 8080 8081 8082; do # Check script parameters if [ $# -eq 1 ]; then if [ $1 = "--no-rate-limit" ]; then - # messages rate limit - echo 'rc_messages_per_second: 1000' >> $DIR/etc/$port.config - echo 'rc_message_burst_count: 1000' >> $DIR/etc/$port.config - - # registration rate limit - printf 'rc_registration:\n per_second: 1000\n burst_count: 1000\n' >> $DIR/etc/$port.config - - # login rate limit - echo 'rc_login:' >> $DIR/etc/$port.config - printf ' address:\n per_second: 1000\n burst_count: 1000\n' >> $DIR/etc/$port.config - printf ' account:\n per_second: 1000\n burst_count: 1000\n' >> $DIR/etc/$port.config - printf ' failed_attempts:\n per_second: 1000\n burst_count: 1000\n' >> $DIR/etc/$port.config + + # Disable any rate limiting + ratelimiting=$(cat <<-RC + rc_message: + per_second: 1000 + burst_count: 1000 + rc_registration: + per_second: 1000 + burst_count: 1000 + rc_login: + address: + per_second: 1000 + burst_count: 1000 + account: + per_second: 1000 + burst_count: 1000 + failed_attempts: + per_second: 1000 + burst_count: 1000 + rc_admin_redaction: + per_second: 1000 + burst_count: 1000 + rc_joins: + local: + per_second: 1000 + burst_count: 1000 + remote: + per_second: 1000 + burst_count: 1000 + rc_3pid_validation: + per_second: 1000 + burst_count: 1000 + rc_invites: + per_room: + per_second: 1000 + burst_count: 1000 + per_user: + per_second: 1000 + burst_count: 1000 + RC + ) + echo "${ratelimiting}" >> $DIR/etc/$port.config fi fi diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index bc737b30f5..01d3882426 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -427,7 +427,7 @@ the new room. Users on other servers will be unaffected. The API is: ``` -POST /_synapse/admin/v1/rooms/<room_id>/delete +DELETE /_synapse/admin/v1/rooms/<room_id> ``` with a body of: @@ -528,6 +528,15 @@ You will have to manually handle, if you so choose, the following: * Users that would have been booted from the room (and will have been force-joined to the Content Violation room). * Removal of the Content Violation room if desired. +## Deprecated endpoint + +The previous deprecated API will be removed in a future release, it was: + +``` +POST /_synapse/admin/v1/rooms/<room_id>/delete +``` + +It behaves the same way than the current endpoint except the path and the method. # Make Room Admin API diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9182dcd987..e0350279ad 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1175,69 +1175,6 @@ url_preview_accept_language: # #enable_registration: false -# Optional account validity configuration. This allows for accounts to be denied -# any request after a given period. -# -# Once this feature is enabled, Synapse will look for registered users without an -# expiration date at startup and will add one to every account it found using the -# current settings at that time. -# This means that, if a validity period is set, and Synapse is restarted (it will -# then derive an expiration date from the current validity period), and some time -# after that the validity period changes and Synapse is restarted, the users' -# expiration dates won't be updated unless their account is manually renewed. This -# date will be randomly selected within a range [now + period - d ; now + period], -# where d is equal to 10% of the validity period. -# -account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - #template_dir: "res/templates" - - # File within 'template_dir' giving the HTML to be displayed to the user after - # they successfully renewed their account. If not set, default text is used. - # - #account_renewed_html_path: "account_renewed.html" - - # File within 'template_dir' giving the HTML to be displayed when the user - # tries to renew an account with an invalid renewal token. If not set, - # default text is used. - # - #invalid_token_html_path: "invalid_token.html" - # Time that a user's session remains valid for, after they log in. # # Note that this is not currently compatible with guest logins. @@ -1432,6 +1369,91 @@ account_threepid_delegates: #auto_join_rooms_for_guests: false +## Account Validity ## + +# Optional account validity configuration. This allows for accounts to be denied +# any request after a given period. +# +# Once this feature is enabled, Synapse will look for registered users without an +# expiration date at startup and will add one to every account it found using the +# current settings at that time. +# This means that, if a validity period is set, and Synapse is restarted (it will +# then derive an expiration date from the current validity period), and some time +# after that the validity period changes and Synapse is restarted, the users' +# expiration dates won't be updated unless their account is manually renewed. This +# date will be randomly selected within a range [now + period - d ; now + period], +# where d is equal to 10% of the validity period. +# +account_validity: + # The account validity feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # The period after which an account is valid after its registration. When + # renewing the account, its validity period will be extended by this amount + # of time. This parameter is required when using the account validity + # feature. + # + #period: 6w + + # The amount of time before an account's expiry date at which Synapse will + # send an email to the account's email address with a renewal link. By + # default, no such emails are sent. + # + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. + # + #renew_at: 1w + + # The subject of the email sent out with the renewal link. '%(app)s' can be + # used as a placeholder for the 'app_name' parameter from the 'email' + # section. + # + # Note that the placeholder must be written '%(app)s', including the + # trailing 's'. + # + # If this is not set, a default value is used. + # + #renew_email_subject: "Renew your %(app)s account" + + # Directory in which Synapse will try to find templates for the HTML files to + # serve to the user when trying to renew an account. If not set, default + # templates from within the Synapse package will be used. + # + # The currently available templates are: + # + # * account_renewed.html: Displayed to the user after they have successfully + # renewed their account. + # + # * account_previously_renewed.html: Displayed to the user if they attempt to + # renew their account with a token that is valid, but that has already + # been used. In this case the account is not renewed again. + # + # * invalid_token.html: Displayed to the user when they try to renew an account + # with an unknown or invalid renewal token. + # + # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for + # default template contents. + # + # The file name of some of these templates can be configured below for legacy + # reasons. + # + #template_dir: "res/templates" + + # A custom file name for the 'account_renewed.html' template. + # + # If not set, the file is assumed to be named "account_renewed.html". + # + #account_renewed_html_path: "account_renewed.html" + + # A custom file name for the 'invalid_token.html' template. + # + # If not set, the file is assumed to be named "invalid_token.html". + # + #invalid_token_html_path: "invalid_token.html" + + ## Metrics ### # Enable collection and rendering of performance metrics @@ -1878,7 +1900,7 @@ saml2_config: # sub-properties: # # module: The class name of a custom mapping module. Default is -# 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. +# 'synapse.handlers.oidc.JinjaOidcMappingProvider'. # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers # for information on implementing a custom mapping provider. # diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index e1d6ede7ba..50020d1a4a 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -106,7 +106,7 @@ A custom mapping provider must specify the following methods: Synapse has a built-in OpenID mapping provider if a custom provider isn't specified in the config. It is located at -[`synapse.handlers.oidc_handler.JinjaOidcMappingProvider`](../synapse/handlers/oidc_handler.py). +[`synapse.handlers.oidc.JinjaOidcMappingProvider`](../synapse/handlers/oidc.py). ## SAML Mapping Providers @@ -190,4 +190,4 @@ A custom mapping provider must specify the following methods: Synapse has a built-in SAML mapping provider if a custom provider isn't specified in the config. It is located at -[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py). +[`synapse.handlers.saml.DefaultSamlMappingProvider`](../synapse/handlers/saml.py). diff --git a/mypy.ini b/mypy.ini index 32e6197409..a40f705b76 100644 --- a/mypy.ini +++ b/mypy.ini @@ -41,7 +41,6 @@ files = synapse/push, synapse/replication, synapse/rest, - synapse/secrets.py, synapse/server.py, synapse/server_notices, synapse/spam_checker_api, diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 313860df13..c82ddd9677 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -140,7 +140,7 @@ if __name__ == "__main__": definitions = {} for directory in args.directories: - for root, dirs, files in os.walk(directory): + for root, _, files in os.walk(directory): for filename in files: if filename.endswith(".py"): filepath = os.path.join(root, filename) diff --git a/scripts-dev/list_url_patterns.py b/scripts-dev/list_url_patterns.py index 26ad7c67f4..e85420dea8 100755 --- a/scripts-dev/list_url_patterns.py +++ b/scripts-dev/list_url_patterns.py @@ -48,7 +48,7 @@ args = parser.parse_args() for directory in args.directories: - for root, dirs, files in os.walk(directory): + for root, _, files in os.walk(directory): for filename in files: if filename.endswith(".py"): filepath = os.path.join(root, filename) diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index b7c1ffc956..f0c93d5226 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -634,8 +634,11 @@ class Porter(object): "device_inbox_sequence", ("device_inbox", "device_federation_outbox") ) await self._setup_sequence( - "account_data_sequence", ("room_account_data", "room_tags_revisions", "account_data")) - await self._setup_sequence("receipts_sequence", ("receipts_linearized", )) + "account_data_sequence", + ("room_account_data", "room_tags_revisions", "account_data"), + ) + await self._setup_sequence("receipts_sequence", ("receipts_linearized",)) + await self._setup_sequence("presence_stream_sequence", ("presence_stream",)) await self._setup_auth_chain_sequence() # Step 3. Get tables. diff --git a/setup.cfg b/setup.cfg index 33601b71d5..e5ceb7ed19 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,8 +18,7 @@ ignore = # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) -# B007: Subsection of the bugbear suite (TODO: add in remaining fixes) -ignore=W503,W504,E203,E731,E501,B007 +ignore=W503,W504,E203,E731,E501 [isort] line_length = 88 diff --git a/synapse/__init__.py b/synapse/__init__.py index c2709c2ad4..319c52be2c 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -21,8 +21,8 @@ import os import sys # Check that we're not running on an unsupported Python version. -if sys.version_info < (3, 5): - print("Synapse requires Python 3.5 or above.") +if sys.version_info < (3, 6): + print("Synapse requires Python 3.6 or above.") sys.exit(1) # Twisted and canonicaljson will fail to import when this file is executed to @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.32.0rc1" +__version__ = "1.33.0rc2" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 6c13f53957..efc926d094 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import pymacaroons from netaddr import IPAddress from twisted.web.server import Request -import synapse.types from synapse import event_auth from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import EventTypes, HistoryVisibility, Membership @@ -36,11 +35,14 @@ from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging import opentracing as opentracing from synapse.storage.databases.main.registration import TokenLookupResult -from synapse.types import StateMap, UserID +from synapse.types import Requester, StateMap, UserID, create_requester from synapse.util.caches.lrucache import LruCache from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -65,9 +67,10 @@ class Auth: """ FIXME: This class contains a mix of functions for authenticating users of our client-server API and authenticating events added to room graphs. + The latter should be moved to synapse.handlers.event_auth.EventAuthHandler. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -79,19 +82,21 @@ class Auth: self._auth_blocking = AuthBlocking(self.hs) - self._account_validity = hs.config.account_validity + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key async def check_from_context( self, room_version: str, event, context, do_sig_check=True - ): + ) -> None: prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_by_id = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()} room_version_obj = KNOWN_ROOM_VERSIONS[room_version] event_auth.check( @@ -148,17 +153,11 @@ class Auth: raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) - async def check_host_in_room(self, room_id, host): + async def check_host_in_room(self, room_id: str, host: str) -> bool: with Measure(self.clock, "check_host_in_room"): - latest_event_ids = await self.store.is_host_joined(room_id, host) - return latest_event_ids - - def can_federate(self, event, auth_events): - creation_event = auth_events.get((EventTypes.Create, "")) + return await self.store.is_host_joined(room_id, host) - return creation_event.content.get("m.federate", True) is True - - def get_public_keys(self, invite_event): + def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]: return event_auth.get_public_keys(invite_event) async def get_user_by_req( @@ -167,7 +166,7 @@ class Auth: allow_guest: bool = False, rights: str = "access", allow_expired: bool = False, - ) -> synapse.types.Requester: + ) -> Requester: """Get a registered user's ID. Args: @@ -193,7 +192,7 @@ class Auth: access_token = self.get_access_token_from_request(request) user_id, app_service = await self._get_appservice_user_id(request) - if user_id: + if user_id and app_service: if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( user_id=user_id, @@ -203,9 +202,7 @@ class Auth: device_id="dummy-device", # stubbed ) - requester = synapse.types.create_requester( - user_id, app_service=app_service - ) + requester = create_requester(user_id, app_service=app_service) request.requester = user_id opentracing.set_tag("authenticated_entity", user_id) @@ -222,7 +219,7 @@ class Auth: shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. - if self._account_validity.enabled and not allow_expired: + if self._account_validity_enabled and not allow_expired: if await self.store.is_account_expired( user_info.user_id, self.clock.time_msec() ): @@ -248,7 +245,7 @@ class Auth: errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) - requester = synapse.types.create_requester( + requester = create_requester( user_info.user_id, token_id, is_guest, @@ -268,7 +265,9 @@ class Auth: except KeyError: raise MissingClientTokenError() - async def _get_appservice_user_id(self, request): + async def _get_appservice_user_id( + self, request: Request + ) -> Tuple[Optional[str], Optional[ApplicationService]]: app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) @@ -280,6 +279,9 @@ class Auth: if ip_address not in app_service.ip_range_whitelist: return None, None + # This will always be set by the time Twisted calls us. + assert request.args is not None + if b"user_id" not in request.args: return app_service.sender, app_service @@ -384,7 +386,9 @@ class Auth: logger.warning("Invalid macaroon in auth: %s %s", type(e), e) raise InvalidClientTokenError("Invalid macaroon passed.") - def _parse_and_validate_macaroon(self, token, rights="access"): + def _parse_and_validate_macaroon( + self, token: str, rights: str = "access" + ) -> Tuple[str, bool]: """Takes a macaroon and tries to parse and validate it. This is cached if and only if rights == access and there isn't an expiry. @@ -429,15 +433,16 @@ class Auth: return user_id, guest - def validate_macaroon(self, macaroon, type_string, user_id): + def validate_macaroon( + self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str + ) -> None: """ validate that a Macaroon is understood by and was signed by this server. Args: - macaroon(pymacaroons.Macaroon): The macaroon to validate - type_string(str): The kind of token required (e.g. "access", - "delete_pusher") - user_id (str): The user_id required + macaroon: The macaroon to validate + type_string: The kind of token required (e.g. "access", "delete_pusher") + user_id: The user_id required """ v = pymacaroons.Verifier() @@ -462,9 +467,7 @@ class Auth: if not service: logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() - request.requester = synapse.types.create_requester( - service.sender, app_service=service - ) + request.requester = create_requester(service.sender, app_service=service) return service async def is_server_admin(self, user: UserID) -> bool: @@ -516,7 +519,7 @@ class Auth: return auth_ids - async def check_can_change_room_list(self, room_id: str, user: UserID): + async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool: """Determine whether the user is allowed to edit the room's entry in the published room list. @@ -551,11 +554,11 @@ class Auth: return user_level >= send_level @staticmethod - def has_access_token(request: Request): + def has_access_token(request: Request) -> bool: """Checks if the request has an access_token. Returns: - bool: False if no access_token was given, True otherwise. + False if no access_token was given, True otherwise. """ # This will always be set by the time Twisted calls us. assert request.args is not None @@ -565,13 +568,13 @@ class Auth: return bool(query_params) or bool(auth_headers) @staticmethod - def get_access_token_from_request(request: Request): + def get_access_token_from_request(request: Request) -> str: """Extracts the access_token from the request. Args: request: The http request. Returns: - unicode: The access_token + The access_token Raises: MissingClientTokenError: If there isn't a single access_token in the request @@ -646,5 +649,5 @@ class Auth: % (user_id, room_id), ) - def check_auth_blocking(self, *args, **kwargs): - return self._auth_blocking.check_auth_blocking(*args, **kwargs) + async def check_auth_blocking(self, *args, **kwargs) -> None: + await self._auth_blocking.check_auth_blocking(*args, **kwargs) diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index a8df60cb89..e6bced93d5 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -13,18 +13,21 @@ # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError from synapse.config.server import is_threepid_reserved from synapse.types import Requester +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class AuthBlocking: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self._server_notices_mxid = hs.config.server_notices_mxid @@ -43,7 +46,7 @@ class AuthBlocking: threepid: Optional[dict] = None, user_type: Optional[str] = None, requester: Optional[Requester] = None, - ): + ) -> None: """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 31a59bceec..936b6534b4 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,6 +17,9 @@ """Contains constants from the specification.""" +# the max size of a (canonical-json-encoded) event +MAX_PDU_SIZE = 65536 + # the "depth" field on events is limited to 2**63 - 1 MAX_DEPTH = 2 ** 63 - 1 diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2113c4f370..638e01c1b2 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -30,9 +30,10 @@ from twisted.internet import defer, error, reactor from twisted.protocols.tls import TLSMemoryBIOFactory import synapse +from synapse.api.constants import MAX_PDU_SIZE from synapse.app import check_bind_error from synapse.app.phone_stats_home import start_phone_stats_home -from synapse.config.server import ListenerConfig +from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -288,7 +289,7 @@ def refresh_certificate(hs): logger.info("Context factories updated.") -async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): +async def start(hs: "synapse.server.HomeServer"): """ Start a Synapse server or worker. @@ -300,7 +301,6 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon Args: hs: homeserver instance - listeners: Listener configuration ('listeners' in homeserver.yaml) """ # Set up the SIGHUP machinery. if hasattr(signal, "SIGHUP"): @@ -336,7 +336,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa # It is now safe to start your Synapse. - hs.start_listening(listeners) + hs.start_listening() hs.get_datastore().db_pool.start_profiling() hs.get_pusherpool().start() @@ -530,3 +530,25 @@ def sdnotify(state): # this is a bit surprising, since we don't expect to have a NOTIFY_SOCKET # unless systemd is expecting us to notify it. logger.warning("Unable to send notification to systemd: %s", e) + + +def max_request_body_size(config: HomeServerConfig) -> int: + """Get a suitable maximum size for incoming HTTP requests""" + + # Other than media uploads, the biggest request we expect to see is a fully-loaded + # /federation/v1/send request. + # + # The main thing in such a request is up to 50 PDUs, and up to 100 EDUs. PDUs are + # limited to 65536 bytes (possibly slightly more if the sender didn't use canonical + # json encoding); there is no specced limit to EDUs (see + # https://github.com/matrix-org/matrix-doc/issues/3121). + # + # in short, we somewhat arbitrarily limit requests to 200 * 64K (about 12.5M) + # + max_request_size = 200 * MAX_PDU_SIZE + + # if we have a media repo enabled, we may need to allow larger uploads than that + if config.media.can_load_media_repo: + max_request_size = max(max_request_size, config.media.max_upload_size) + + return max_request_size diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index eb256db749..68ae19c977 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -70,12 +70,6 @@ class AdminCmdSlavedStore( class AdminCmdServer(HomeServer): DATASTORE_CLASS = AdminCmdSlavedStore - def _listen_http(self, listener_config): - pass - - def start_listening(self, listeners): - pass - async def export_data_command(hs, args): """Export data for a user. @@ -232,7 +226,7 @@ def start(config_options): async def run(): with LoggingContext("command"): - _base.start(ss, []) + _base.start(ss) await args.func(ss, args) _base.start_worker_reactor( diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 26c458dbb6..1a15ceee81 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -15,7 +15,7 @@ # limitations under the License. import logging import sys -from typing import Dict, Iterable, Optional +from typing import Dict, Optional from twisted.internet import address from twisted.web.resource import IResource @@ -32,7 +32,7 @@ from synapse.api.urls import ( SERVER_KEY_V2_PREFIX, ) from synapse.app import _base -from synapse.app._base import register_start +from synapse.app._base import max_request_body_size, register_start from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging @@ -55,7 +55,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.filtering import SlavedFilteringStore from synapse.replication.slave.storage.groups import SlavedGroupServerStore from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.profile import SlavedProfileStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.pushers import SlavedPusherStore @@ -64,7 +63,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events, login, room +from synapse.rest.client.v1 import events, login, presence, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.profile import ( ProfileAvatarURLRestServlet, @@ -110,6 +109,7 @@ from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) +from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.search import SearchWorkerStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.transactions import TransactionWorkerStore @@ -121,26 +121,6 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.generic_worker") -class PresenceStatusStubServlet(RestServlet): - """If presence is disabled this servlet can be used to stub out setting - presence status. - """ - - PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status") - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - - async def on_GET(self, request, user_id): - await self.auth.get_user_by_req(request) - return 200, {"presence": "offline"} - - async def on_PUT(self, request, user_id): - await self.auth.get_user_by_req(request) - return 200, {} - - class KeyUploadServlet(RestServlet): """An implementation of the `KeyUploadServlet` that responds to read only requests, but otherwise proxies through to the master instance. @@ -241,6 +221,7 @@ class GenericWorkerSlavedStore( StatsStore, UIAuthWorkerStore, EndToEndRoomKeyStore, + PresenceStore, SlavedDeviceInboxStore, SlavedDeviceStore, SlavedReceiptsStore, @@ -259,7 +240,6 @@ class GenericWorkerSlavedStore( SlavedTransactionStore, SlavedProfileStore, SlavedClientIpStore, - SlavedPresenceStore, SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, @@ -327,10 +307,7 @@ class GenericWorkerServer(HomeServer): user_directory.register_servlets(self, resource) - # If presence is disabled, use the stub servlet that does - # not allow sending presence - if not self.config.use_presence: - PresenceStatusStubServlet(self).register(resource) + presence.register_servlets(self, resource) groups.register_servlets(self, resource) @@ -390,14 +367,16 @@ class GenericWorkerServer(HomeServer): listener_config, root_resource, self.version_string, + max_request_body_size=max_request_body_size(self.config), + reactor=self.get_reactor(), ), reactor=self.get_reactor(), ) logger.info("Synapse worker now listening on port %d", port) - def start_listening(self, listeners: Iterable[ListenerConfig]): - for listener in listeners: + def start_listening(self): + for listener in self.config.worker_listeners: if listener.type == "http": self._listen_http(listener) elif listener.type == "manhole": @@ -490,7 +469,7 @@ def start(config_options): # streams. Will no-op if no streams can be written to by this worker. hs.get_replication_streamer() - register_start(_base.start, hs, config.worker_listeners) + register_start(_base.start, hs) _base.start_worker_reactor("synapse-generic-worker", config) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8be8b520eb..8e78134bbe 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -17,7 +17,7 @@ import logging import os import sys -from typing import Iterable, Iterator +from typing import Iterator from twisted.internet import reactor from twisted.web.resource import EncodingResourceWrapper, IResource @@ -36,7 +36,13 @@ from synapse.api.urls import ( WEB_CLIENT_PREFIX, ) from synapse.app import _base -from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start +from synapse.app._base import ( + listen_ssl, + listen_tcp, + max_request_body_size, + quit_with_error, + register_start, +) from synapse.config._base import ConfigError from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.homeserver import HomeServerConfig @@ -126,19 +132,21 @@ class SynapseHomeServer(HomeServer): else: root_resource = OptionsResource() - root_resource = create_resource_tree(resources, root_resource) + site = SynapseSite( + "synapse.access.%s.%s" % ("https" if tls else "http", site_tag), + site_tag, + listener_config, + create_resource_tree(resources, root_resource), + self.version_string, + max_request_body_size=max_request_body_size(self.config), + reactor=self.get_reactor(), + ) if tls: ports = listen_ssl( bind_addresses, port, - SynapseSite( - "synapse.access.https.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), + site, self.tls_server_context_factory, reactor=self.get_reactor(), ) @@ -148,13 +156,7 @@ class SynapseHomeServer(HomeServer): ports = listen_tcp( bind_addresses, port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), + site, reactor=self.get_reactor(), ) logger.info("Synapse now listening on TCP port %d", port) @@ -273,14 +275,14 @@ class SynapseHomeServer(HomeServer): return resources - def start_listening(self, listeners: Iterable[ListenerConfig]): + def start_listening(self): if self.config.redis_enabled: # If redis is enabled we connect via the replication command handler # in the same way as the workers (since we're effectively a client # rather than a server). self.get_tcp_replication().start_replication(self) - for listener in listeners: + for listener in self.config.server.listeners: if listener.type == "http": self._listening_services.extend( self._listener_http(self.config, listener) @@ -412,7 +414,7 @@ def setup(config_options): # Loading the provider metadata also ensures the provider config is valid. await oidc.load_metadata() - await _base.start(hs, config.listeners) + await _base.start(hs) hs.get_datastore().db_pool.updates.start_doing_background_updates() diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index e896fd34e2..ff9abbc232 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,21 +1,22 @@ from typing import Any, Iterable, List, Optional from synapse.config import ( + account_validity, api, appservice, auth, captcha, cas, - consent_config, + consent, database, emailconfig, experimental, groups, - jwt_config, + jwt, key, logger, metrics, - oidc_config, + oidc, password_auth_providers, push, ratelimiting, @@ -23,9 +24,9 @@ from synapse.config import ( registration, repository, room_directory, - saml2_config, + saml2, server, - server_notices_config, + server_notices, spam_checker, sso, stats, @@ -59,15 +60,16 @@ class RootConfig: captcha: captcha.CaptchaConfig voip: voip.VoipConfig registration: registration.RegistrationConfig + account_validity: account_validity.AccountValidityConfig metrics: metrics.MetricsConfig api: api.ApiConfig appservice: appservice.AppServiceConfig key: key.KeyConfig - saml2: saml2_config.SAML2Config + saml2: saml2.SAML2Config cas: cas.CasConfig sso: sso.SSOConfig - oidc: oidc_config.OIDCConfig - jwt: jwt_config.JWTConfig + oidc: oidc.OIDCConfig + jwt: jwt.JWTConfig auth: auth.AuthConfig email: emailconfig.EmailConfig worker: workers.WorkerConfig @@ -76,9 +78,9 @@ class RootConfig: spamchecker: spam_checker.SpamCheckerConfig groups: groups.GroupsConfig userdirectory: user_directory.UserDirectoryConfig - consent: consent_config.ConsentConfig + consent: consent.ConsentConfig stats: stats.StatsConfig - servernotices: server_notices_config.ServerNoticesConfig + servernotices: server_notices.ServerNoticesConfig roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig tracer: tracer.TracerConfig diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py new file mode 100644 index 0000000000..c58a7d95a7 --- /dev/null +++ b/synapse/config/account_validity.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config._base import Config, ConfigError + + +class AccountValidityConfig(Config): + section = "account_validity" + + def read_config(self, config, **kwargs): + account_validity_config = config.get("account_validity") or {} + self.account_validity_enabled = account_validity_config.get("enabled", False) + self.account_validity_renew_by_email_enabled = ( + "renew_at" in account_validity_config + ) + + if self.account_validity_enabled: + if "period" in account_validity_config: + self.account_validity_period = self.parse_duration( + account_validity_config["period"] + ) + else: + raise ConfigError("'period' is required when using account validity") + + if "renew_at" in account_validity_config: + self.account_validity_renew_at = self.parse_duration( + account_validity_config["renew_at"] + ) + + if "renew_email_subject" in account_validity_config: + self.account_validity_renew_email_subject = account_validity_config[ + "renew_email_subject" + ] + else: + self.account_validity_renew_email_subject = "Renew your %(app)s account" + + self.account_validity_startup_job_max_delta = ( + self.account_validity_period * 10.0 / 100.0 + ) + + if self.account_validity_renew_by_email_enabled: + if not self.public_baseurl: + raise ConfigError("Can't send renewal emails without 'public_baseurl'") + + # Load account validity templates. + account_validity_template_dir = account_validity_config.get("template_dir") + + account_renewed_template_filename = account_validity_config.get( + "account_renewed_html_path", "account_renewed.html" + ) + invalid_token_template_filename = account_validity_config.get( + "invalid_token_html_path", "invalid_token.html" + ) + + # Read and store template content + ( + self.account_validity_account_renewed_template, + self.account_validity_account_previously_renewed_template, + self.account_validity_invalid_token_template, + ) = self.read_templates( + [ + account_renewed_template_filename, + "account_previously_renewed.html", + invalid_token_template_filename, + ], + account_validity_template_dir, + ) + + def generate_config_section(self, **kwargs): + return """\ + ## Account Validity ## + + # Optional account validity configuration. This allows for accounts to be denied + # any request after a given period. + # + # Once this feature is enabled, Synapse will look for registered users without an + # expiration date at startup and will add one to every account it found using the + # current settings at that time. + # This means that, if a validity period is set, and Synapse is restarted (it will + # then derive an expiration date from the current validity period), and some time + # after that the validity period changes and Synapse is restarted, the users' + # expiration dates won't be updated unless their account is manually renewed. This + # date will be randomly selected within a range [now + period - d ; now + period], + # where d is equal to 10% of the validity period. + # + account_validity: + # The account validity feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # The period after which an account is valid after its registration. When + # renewing the account, its validity period will be extended by this amount + # of time. This parameter is required when using the account validity + # feature. + # + #period: 6w + + # The amount of time before an account's expiry date at which Synapse will + # send an email to the account's email address with a renewal link. By + # default, no such emails are sent. + # + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. + # + #renew_at: 1w + + # The subject of the email sent out with the renewal link. '%(app)s' can be + # used as a placeholder for the 'app_name' parameter from the 'email' + # section. + # + # Note that the placeholder must be written '%(app)s', including the + # trailing 's'. + # + # If this is not set, a default value is used. + # + #renew_email_subject: "Renew your %(app)s account" + + # Directory in which Synapse will try to find templates for the HTML files to + # serve to the user when trying to renew an account. If not set, default + # templates from within the Synapse package will be used. + # + # The currently available templates are: + # + # * account_renewed.html: Displayed to the user after they have successfully + # renewed their account. + # + # * account_previously_renewed.html: Displayed to the user if they attempt to + # renew their account with a token that is valid, but that has already + # been used. In this case the account is not renewed again. + # + # * invalid_token.html: Displayed to the user when they try to renew an account + # with an unknown or invalid renewal token. + # + # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for + # default template contents. + # + # The file name of some of these templates can be configured below for legacy + # reasons. + # + #template_dir: "res/templates" + + # A custom file name for the 'account_renewed.html' template. + # + # If not set, the file is assumed to be named "account_renewed.html". + # + #account_renewed_html_path: "account_renewed.html" + + # A custom file name for the 'invalid_token.html' template. + # + # If not set, the file is assumed to be named "invalid_token.html". + # + #invalid_token_html_path: "invalid_token.html" + """ diff --git a/synapse/config/consent_config.py b/synapse/config/consent.py index 30d07cc219..30d07cc219 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent.py diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index c587939c7a..5564d7d097 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -299,7 +299,7 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if self.account_validity.renew_by_email_enabled: + if self.account_validity_renew_by_email_enabled: expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 1309535068..c23b66c88c 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -12,25 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from ._base import RootConfig +from .account_validity import AccountValidityConfig from .api import ApiConfig from .appservice import AppServiceConfig from .auth import AuthConfig from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig -from .consent_config import ConsentConfig +from .consent import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig from .experimental import ExperimentalConfig from .federation import FederationConfig from .groups import GroupsConfig -from .jwt_config import JWTConfig +from .jwt import JWTConfig from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig -from .oidc_config import OIDCConfig +from .oidc import OIDCConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig @@ -39,9 +39,9 @@ from .registration import RegistrationConfig from .repository import ContentRepositoryConfig from .room import RoomConfig from .room_directory import RoomDirectoryConfig -from .saml2_config import SAML2Config +from .saml2 import SAML2Config from .server import ServerConfig -from .server_notices_config import ServerNoticesConfig +from .server_notices import ServerNoticesConfig from .spam_checker import SpamCheckerConfig from .sso import SSOConfig from .stats import StatsConfig @@ -68,6 +68,7 @@ class HomeServerConfig(RootConfig): CaptchaConfig, VoipConfig, RegistrationConfig, + AccountValidityConfig, MetricsConfig, ApiConfig, AppServiceConfig, diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt.py index 9e07e73008..9e07e73008 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt.py diff --git a/synapse/config/logger.py b/synapse/config/logger.py index b174e0df6d..813076dfe2 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -31,7 +31,6 @@ from twisted.logger import ( ) import synapse -from synapse.app import _base as appbase from synapse.logging._structured import setup_structured_logging from synapse.logging.context import LoggingContextFilter from synapse.logging.filter import MetadataFilter @@ -318,6 +317,8 @@ def setup_logging( # Perform one-time logging configuration. _setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner) # Add a SIGHUP handler to reload the logging configuration, if one is available. + from synapse.app import _base as appbase + appbase.register_sighup(_reload_logging_config, log_config_path) # Log immediately so we can grep backwards. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc.py index 5fb94376fd..ea0abf5aa2 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc.py @@ -14,20 +14,23 @@ # limitations under the License. from collections import Counter -from typing import Iterable, List, Mapping, Optional, Tuple, Type +from typing import Collection, Iterable, List, Mapping, Optional, Tuple, Type import attr from synapse.config._util import validate_config from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements -from synapse.types import Collection, JsonDict +from synapse.types import JsonDict from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_mxc_uri from ._base import Config, ConfigError, read_file -DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider" +DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc.JinjaOidcMappingProvider" +# The module that JinjaOidcMappingProvider is in was renamed, we want to +# transparently handle both the same. +LEGACY_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider" class OIDCConfig(Config): @@ -403,6 +406,8 @@ def _parse_oidc_config_dict( """ ump_config = oidc_config.get("user_mapping_provider", {}) ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + if ump_config.get("module") == LEGACY_USER_MAPPING_PROVIDER: + ump_config["module"] = DEFAULT_USER_MAPPING_PROVIDER ump_config.setdefault("config", {}) ( diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f8a2768af8..e6f52b4f40 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -12,74 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -import pkg_resources - from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols, strtobool -class AccountValidityConfig(Config): - section = "accountvalidity" - - def __init__(self, config, synapse_config): - if config is None: - return - super().__init__() - self.enabled = config.get("enabled", False) - self.renew_by_email_enabled = "renew_at" in config - - if self.enabled: - if "period" in config: - self.period = self.parse_duration(config["period"]) - else: - raise ConfigError("'period' is required when using account validity") - - if "renew_at" in config: - self.renew_at = self.parse_duration(config["renew_at"]) - - if "renew_email_subject" in config: - self.renew_email_subject = config["renew_email_subject"] - else: - self.renew_email_subject = "Renew your %(app)s account" - - self.startup_job_max_delta = self.period * 10.0 / 100.0 - - if self.renew_by_email_enabled: - if "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") - - template_dir = config.get("template_dir") - - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates") - - if "account_renewed_html_path" in config: - file_path = os.path.join(template_dir, config["account_renewed_html_path"]) - - self.account_renewed_html_content = self.read_file( - file_path, "account_validity.account_renewed_html_path" - ) - else: - self.account_renewed_html_content = ( - "<html><body>Your account has been successfully renewed.</body><html>" - ) - - if "invalid_token_html_path" in config: - file_path = os.path.join(template_dir, config["invalid_token_html_path"]) - - self.invalid_token_html_content = self.read_file( - file_path, "account_validity.invalid_token_html_path" - ) - else: - self.invalid_token_html_content = ( - "<html><body>Invalid renewal token.</body><html>" - ) - - class RegistrationConfig(Config): section = "registration" @@ -92,10 +30,6 @@ class RegistrationConfig(Config): str(config["disable_registration"]) ) - self.account_validity = AccountValidityConfig( - config.get("account_validity") or {}, config - ) - self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) @@ -207,69 +141,6 @@ class RegistrationConfig(Config): # #enable_registration: false - # Optional account validity configuration. This allows for accounts to be denied - # any request after a given period. - # - # Once this feature is enabled, Synapse will look for registered users without an - # expiration date at startup and will add one to every account it found using the - # current settings at that time. - # This means that, if a validity period is set, and Synapse is restarted (it will - # then derive an expiration date from the current validity period), and some time - # after that the validity period changes and Synapse is restarted, the users' - # expiration dates won't be updated unless their account is manually renewed. This - # date will be randomly selected within a range [now + period - d ; now + period], - # where d is equal to 10%% of the validity period. - # - account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %%(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - #template_dir: "res/templates" - - # File within 'template_dir' giving the HTML to be displayed to the user after - # they successfully renewed their account. If not set, default text is used. - # - #account_renewed_html_path: "account_renewed.html" - - # File within 'template_dir' giving the HTML to be displayed when the user - # tries to renew an account with an invalid renewal token. If not set, - # default text is used. - # - #invalid_token_html_path: "invalid_token.html" - # Time that a user's session remains valid for, after they log in. # # Note that this is not currently compatible with guest logins. diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2.py index 55a7838b10..3d1218c8d1 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2.py @@ -25,7 +25,10 @@ from ._util import validate_config logger = logging.getLogger(__name__) -DEFAULT_USER_MAPPING_PROVIDER = ( +DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.saml.DefaultSamlMappingProvider" +# The module that DefaultSamlMappingProvider is in was renamed, we want to +# transparently handle both the same. +LEGACY_USER_MAPPING_PROVIDER = ( "synapse.handlers.saml_handler.DefaultSamlMappingProvider" ) @@ -97,6 +100,8 @@ class SAML2Config(Config): # Use the default user mapping provider if not set ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + if ump_dict.get("module") == LEGACY_USER_MAPPING_PROVIDER: + ump_dict["module"] = DEFAULT_USER_MAPPING_PROVIDER # Ensure a config is present ump_dict["config"] = ump_dict.get("config") or {} diff --git a/synapse/config/server.py b/synapse/config/server.py index 02b86b11a5..21ca7b33e3 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -235,7 +235,11 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) + self.public_baseurl = config.get("public_baseurl") + if self.public_baseurl is not None: + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" # Whether to enable user presence. presence_config = config.get("presence") or {} @@ -407,10 +411,6 @@ class ServerConfig(Config): config_path=("federation_ip_range_blacklist",), ) - if self.public_baseurl is not None: - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" - # (undocumented) option for torturing the worker-mode replication a bit, # for testing. The value defines the number of milliseconds to pause before # sending out any replication updates. diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices.py index 48bf3241b6..48bf3241b6 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices.py diff --git a/synapse/config/workers.py b/synapse/config/workers.py index b2540163d1..462630201d 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -64,6 +64,14 @@ class WriterLocations: Attributes: events: The instances that write to the event and backfill streams. typing: The instance that writes to the typing stream. + to_device: The instances that write to the to_device stream. Currently + can only be a single instance. + account_data: The instances that write to the account data streams. Currently + can only be a single instance. + receipts: The instances that write to the receipts stream. Currently + can only be a single instance. + presence: The instances that write to the presence stream. Currently + can only be a single instance. """ events = attr.ib( @@ -85,6 +93,11 @@ class WriterLocations: type=List[str], converter=_instance_to_list_converter, ) + presence = attr.ib( + default=["master"], + type=List[str], + converter=_instance_to_list_converter, + ) class WorkerConfig(Config): @@ -188,7 +201,14 @@ class WorkerConfig(Config): # Check that the configured writers for events and typing also appears in # `instance_map`. - for stream in ("events", "typing", "to_device", "account_data", "receipts"): + for stream in ( + "events", + "typing", + "to_device", + "account_data", + "receipts", + "presence", + ): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: if instance != "master" and instance not in self.instance_map: @@ -215,6 +235,11 @@ class WorkerConfig(Config): if len(self.writers.events) == 0: raise ConfigError("Must specify at least one instance to handle `events`.") + if len(self.writers.presence) != 1: + raise ConfigError( + "Must only specify one instance to handle `presence` messages." + ) + self.events_shard_config = RoutableShardedWorkerHandlingConfig( self.writers.events ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 5234e3f81e..70c556566e 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -14,14 +14,14 @@ # limitations under the License. import logging -from typing import List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import MAX_PDU_SIZE, EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, EventSizeError, SynapseError from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, @@ -205,7 +205,7 @@ def _check_size_limits(event: EventBase) -> None: too_big("type") if len(event.event_id) > 255: too_big("event_id") - if len(encode_canonical_json(event.get_pdu_json())) > 65536: + if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE: too_big("event") @@ -670,7 +670,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase public_key = public_key_object["public_key"] try: for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): + for key_name in signature_block.keys(): if not key_name.startswith("ed25519:"): continue verify_key = decode_verify_key_bytes( @@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase return False -def get_public_keys(invite_event): +def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]: public_keys = [] if "public_key" in invite_event.content: o = {"public_key": invite_event.content["public_key"]} diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index c727b48c1e..d5fa195094 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,12 +15,12 @@ import inspect import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour -from synapse.types import Collection +from synapse.types import RoomAlias from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: @@ -114,7 +114,9 @@ class SpamChecker: return True - async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: + async def user_may_create_room_alias( + self, userid: str, room_alias: RoomAlias + ) -> bool: """Checks if a given user may create a room alias If this method returns false, the association request will be rejected. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f93335edaa..a5b6a61195 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -451,6 +451,28 @@ class FederationClient(FederationBase): return signed_auth + def _is_unknown_endpoint( + self, e: HttpResponseException, synapse_error: Optional[SynapseError] = None + ) -> bool: + """ + Returns true if the response was due to an endpoint being unimplemented. + + Args: + e: The error response received from the remote server. + synapse_error: The above error converted to a SynapseError. This is + automatically generated if not provided. + + """ + if synapse_error is None: + synapse_error = e.to_synapse_error() + # There is no good way to detect an "unknown" endpoint. + # + # Dendrite returns a 404 (with no body); synapse returns a 400 + # with M_UNRECOGNISED. + return e.code == 404 or ( + e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED + ) + async def _try_destination_list( self, description: str, @@ -468,9 +490,9 @@ class FederationClient(FederationBase): callback: Function to run for each server. Passed a single argument: the server_name to try. - If the callback raises a CodeMessageException with a 300/400 code, - attempts to perform the operation stop immediately and the exception is - reraised. + If the callback raises a CodeMessageException with a 300/400 code or + an UnsupportedRoomVersionError, attempts to perform the operation + stop immediately and the exception is reraised. Otherwise, if the callback raises an Exception the error is logged and the next server tried. Normally the stacktrace is logged but this is @@ -492,8 +514,7 @@ class FederationClient(FederationBase): continue try: - res = await callback(destination) - return res + return await callback(destination) except InvalidResponseError as e: logger.warning("Failed to %s via %s: %s", description, destination, e) except UnsupportedRoomVersionError: @@ -502,17 +523,15 @@ class FederationClient(FederationBase): synapse_error = e.to_synapse_error() failover = False + # Failover on an internal server error, or if the destination + # doesn't implemented the endpoint for some reason. if 500 <= e.code < 600: failover = True - elif failover_on_unknown_endpoint: - # there is no good way to detect an "unknown" endpoint. Dendrite - # returns a 404 (with no body); synapse returns a 400 - # with M_UNRECOGNISED. - if e.code == 404 or ( - e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED - ): - failover = True + elif failover_on_unknown_endpoint and self._is_unknown_endpoint( + e, synapse_error + ): + failover = True if not failover: raise synapse_error from e @@ -570,9 +589,8 @@ class FederationClient(FederationBase): UnsupportedRoomVersionError: if remote responds with a room version we don't understand. - SynapseError: if the chosen remote server returns a 300/400 code. - - RuntimeError: if no servers were reachable. + SynapseError: if the chosen remote server returns a 300/400 code, or + no servers successfully handle the request. """ valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: @@ -642,9 +660,8 @@ class FederationClient(FederationBase): ``auth_chain``. Raises: - SynapseError: if the chosen remote server returns a 300/400 code. - - RuntimeError: if no servers were reachable. + SynapseError: if the chosen remote server returns a 300/400 code, or + no servers successfully handle the request. """ async def send_request(destination) -> Dict[str, Any]: @@ -673,7 +690,7 @@ class FederationClient(FederationBase): if create_event is None: # If the state doesn't have a create event then the room is # invalid, and it would fail auth checks anyway. - raise SynapseError(400, "No create event in state") + raise InvalidResponseError("No create event in state") # the room version should be sane. create_room_version = create_event.content.get( @@ -746,16 +763,11 @@ class FederationClient(FederationBase): content=pdu.get_pdu_json(time_now), ) except HttpResponseException as e: - if e.code in [400, 404]: - err = e.to_synapse_error() - - # If we receive an error response that isn't a generic error, or an - # unrecognised endpoint error, we assume that the remote understands - # the v2 invite API and this is a legitimate error. - if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: - raise err - else: - raise e.to_synapse_error() + # If an error is received that is due to an unrecognised endpoint, + # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # and raise. + if not self._is_unknown_endpoint(e): + raise logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") @@ -802,6 +814,11 @@ class FederationClient(FederationBase): Returns: The event as a dict as returned by the remote server + + Raises: + SynapseError: if the remote server returns an error or if the server + only supports the v1 endpoint and a room version other than "1" + or "2" is requested. """ time_now = self._clock.time_msec() @@ -817,28 +834,19 @@ class FederationClient(FederationBase): }, ) except HttpResponseException as e: - if e.code in [400, 404]: - err = e.to_synapse_error() - - # If we receive an error response that isn't a generic error, we - # assume that the remote understands the v2 invite API and this - # is a legitimate error. - if err.errcode != Codes.UNKNOWN: - raise err - - # Otherwise, we assume that the remote server doesn't understand - # the v2 invite API. That's ok provided the room uses old-style event - # IDs. + # If an error is received that is due to an unrecognised endpoint, + # fallback to the v1 endpoint if the room uses old-style event IDs. + # Otherwise consider it a legitmate error and raise. + err = e.to_synapse_error() + if self._is_unknown_endpoint(e, err): if room_version.event_format != EventFormatVersions.V1: raise SynapseError( 400, "User's homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) - elif e.code in (403, 429): - raise e.to_synapse_error() else: - raise + raise err # Didn't work, try v1 API. # Note the v1 API returns a tuple of `(200, content)` @@ -865,9 +873,8 @@ class FederationClient(FederationBase): pdu: event to be sent Raises: - SynapseError if the chosen remote server returns a 300/400 code. - - RuntimeError if no servers were reachable. + SynapseError: if the chosen remote server returns a 300/400 code, or + no servers successfully handle the request. """ async def send_request(destination: str) -> None: @@ -889,16 +896,11 @@ class FederationClient(FederationBase): content=pdu.get_pdu_json(time_now), ) except HttpResponseException as e: - if e.code in [400, 404]: - err = e.to_synapse_error() - - # If we receive an error response that isn't a generic error, or an - # unrecognised endpoint error, we assume that the remote understands - # the v2 invite API and this is a legitimate error. - if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: - raise err - else: - raise e.to_synapse_error() + # If an error is received that is due to an unrecognised endpoint, + # fallback to the v1 endpoint. Otherwise consider it a legitmate error + # and raise. + if not self._is_unknown_endpoint(e): + raise logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API") diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index e3f0bc2471..65d76ea974 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -76,9 +76,6 @@ class FederationRemoteSendQueue(AbstractFederationSender): # Pending presence map user_id -> UserPresenceState self.presence_map = {} # type: Dict[str, UserPresenceState] - # Stream position -> list[user_id] - self.presence_changed = SortedDict() # type: SortedDict[int, List[str]] - # Stores the destinations we need to explicitly send presence to about a # given user. # Stream position -> (user_id, destinations) @@ -96,7 +93,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): self.edus = SortedDict() # type: SortedDict[int, Edu] - # stream ID for the next entry into presence_changed/keyed_edu_changed/edus. + # stream ID for the next entry into keyed_edu_changed/edus. self.pos = 1 # map from stream ID to the time that stream entry was generated, so that we @@ -117,7 +114,6 @@ class FederationRemoteSendQueue(AbstractFederationSender): for queue_name in [ "presence_map", - "presence_changed", "keyed_edu", "keyed_edu_changed", "edus", @@ -155,23 +151,12 @@ class FederationRemoteSendQueue(AbstractFederationSender): """Clear all the queues from before a given position""" with Measure(self.clock, "send_queue._clear"): # Delete things out of presence maps - keys = self.presence_changed.keys() - i = self.presence_changed.bisect_left(position_to_delete) - for key in keys[:i]: - del self.presence_changed[key] - - user_ids = { - user_id for uids in self.presence_changed.values() for user_id in uids - } - keys = self.presence_destinations.keys() i = self.presence_destinations.bisect_left(position_to_delete) for key in keys[:i]: del self.presence_destinations[key] - user_ids.update( - user_id for user_id, _ in self.presence_destinations.values() - ) + user_ids = {user_id for user_id, _ in self.presence_destinations.values()} to_del = [ user_id for user_id in self.presence_map if user_id not in user_ids @@ -244,23 +229,6 @@ class FederationRemoteSendQueue(AbstractFederationSender): """ # nothing to do here: the replication listener will handle it. - def send_presence(self, states: List[UserPresenceState]) -> None: - """As per FederationSender - - Args: - states - """ - pos = self._next_pos() - - # We only want to send presence for our own users, so lets always just - # filter here just in case. - local_states = [s for s in states if self.is_mine_id(s.user_id)] - - self.presence_map.update({state.user_id: state for state in local_states}) - self.presence_changed[pos] = [state.user_id for state in local_states] - - self.notifier.on_new_replication_data() - def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: @@ -325,18 +293,6 @@ class FederationRemoteSendQueue(AbstractFederationSender): # of the federation stream. rows = [] # type: List[Tuple[int, BaseFederationRow]] - # Fetch changed presence - i = self.presence_changed.bisect_right(from_token) - j = self.presence_changed.bisect_right(to_token) + 1 - dest_user_ids = [ - (pos, user_id) - for pos, user_id_list in self.presence_changed.items()[i:j] - for user_id in user_id_list - ] - - for (key, user_id) in dest_user_ids: - rows.append((key, PresenceRow(state=self.presence_map[user_id]))) - # Fetch presence to send to destinations i = self.presence_destinations.bisect_right(from_token) j = self.presence_destinations.bisect_right(to_token) + 1 @@ -427,22 +383,6 @@ class BaseFederationRow: raise NotImplementedError() -class PresenceRow( - BaseFederationRow, namedtuple("PresenceRow", ("state",)) # UserPresenceState -): - TypeId = "p" - - @staticmethod - def from_data(data): - return PresenceRow(state=UserPresenceState.from_dict(data)) - - def to_data(self): - return self.state.as_dict() - - def add_to_buffer(self, buff): - buff.presence.append(self.state) - - class PresenceDestinationsRow( BaseFederationRow, namedtuple( @@ -506,7 +446,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu _rowtypes = ( - PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, @@ -518,7 +457,6 @@ TypeToRow = {Row.TypeId: Row for Row in _rowtypes} ParsedFederationStreamData = namedtuple( "ParsedFederationStreamData", ( - "presence", # list(UserPresenceState) "presence_destinations", # list of tuples of UserPresenceState and destinations "keyed_edus", # dict of destination -> { key -> Edu } "edus", # dict of destination -> [Edu] @@ -543,7 +481,6 @@ def process_rows_for_federation( # them into the appropriate collection and then send them off. buff = ParsedFederationStreamData( - presence=[], presence_destinations=[], keyed_edus={}, edus={}, @@ -559,18 +496,15 @@ def process_rows_for_federation( parsed_row = RowType.from_data(row.data) parsed_row.add_to_buffer(buff) - if buff.presence: - transaction_queue.send_presence(buff.presence) - for state, destinations in buff.presence_destinations: transaction_queue.send_presence_to_destinations( states=[state], destinations=destinations ) - for destination, edu_map in buff.keyed_edus.items(): + for edu_map in buff.keyed_edus.values(): for key, edu in edu_map.items(): transaction_queue.send_edu(edu, key) - for destination, edu_list in buff.edus.items(): + for edu_list in buff.edus.values(): for edu in edu_list: transaction_queue.send_edu(edu, None) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 952ad39f8c..deb40f4610 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -18,14 +18,15 @@ from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, from prometheus_client import Counter +from twisted.internet import defer + import synapse.metrics from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu -from synapse.handlers.presence import get_interested_remotes -from synapse.logging.context import preserve_fn +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import ( LaterGauge, event_processing_loop_counter, @@ -33,8 +34,8 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import Collection, JsonDict, ReadReceipt, RoomStreamToken -from synapse.util.metrics import Measure, measure_func +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.events.presence_router import PresenceRouter @@ -80,15 +81,6 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def send_presence(self, states: List[UserPresenceState]) -> None: - """Send the new presence states to the appropriate destinations. - - This actually queues up the presence states ready for sending and - triggers a background task to process them and send out the transactions. - """ - raise NotImplementedError() - - @abc.abstractmethod def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: @@ -176,11 +168,6 @@ class FederationSender(AbstractFederationSender): ), ) - # Map of user_id -> UserPresenceState for all the pending presence - # to be sent out by user_id. Entries here get processed and put in - # pending_presence_by_dest - self.pending_presence = {} # type: Dict[str, UserPresenceState] - LaterGauge( "synapse_federation_transaction_queue_pending_pdus", "", @@ -201,8 +188,6 @@ class FederationSender(AbstractFederationSender): self._is_processing = False self._last_poked_id = -1 - self._processing_pending_presence = False - # map from room_id to a set of PerDestinationQueues which we believe are # awaiting a call to flush_read_receipts_for_room. The presence of an entry # here for a given room means that we are rate-limiting RR flushes to that room, @@ -270,27 +255,15 @@ class FederationSender(AbstractFederationSender): if not events and next_token >= self._last_poked_id: break - async def get_destinations_for_event( - event: EventBase, - ) -> Collection[str]: - """Computes the destinations to which this event must be sent. - - This returns an empty tuple when there are no destinations to send to, - or if this event is not from this homeserver and it is not sending - it on behalf of another server. - - Will also filter out destinations which this sender is not responsible for, - if multiple federation senders exist. - """ - + async def handle_event(event: EventBase) -> None: # Only send events for this server. send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.sender) if not is_mine and send_on_behalf_of is None: - return () + return if not event.internal_metadata.should_proactively_send(): - return () + return destinations = None # type: Optional[Set[str]] if not event.prev_event_ids(): @@ -325,7 +298,7 @@ class FederationSender(AbstractFederationSender): "Failed to calculate hosts in room for event: %s", event.event_id, ) - return () + return destinations = { d @@ -335,15 +308,17 @@ class FederationSender(AbstractFederationSender): ) } - destinations.discard(self.server_name) - if send_on_behalf_of is not None: # If we are sending the event on behalf of another server # then it already has the event and there is no reason to # send the event to it. destinations.discard(send_on_behalf_of) + logger.debug("Sending %s to %r", event, destinations) + if destinations: + await self._send_pdu(event, destinations) + now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) @@ -351,29 +326,24 @@ class FederationSender(AbstractFederationSender): "federation_sender" ).observe((now - ts) / 1000) - return destinations - return () - - async def get_federatable_events_and_destinations( - events: Iterable[EventBase], - ) -> List[Tuple[EventBase, Collection[str]]]: - with Measure(self.clock, "get_destinations_for_events"): - # Fetch federation destinations per event, - # skip if get_destinations_for_event returns an empty collection, - # return list of event->destinations pairs. - return [ - (event, dests) - for (event, dests) in [ - (event, await get_destinations_for_event(event)) - for event in events - ] - if dests - ] - - events_and_dests = await get_federatable_events_and_destinations(events) - - # Send corresponding events to each destination queue - await self._distribute_events(events_and_dests) + async def handle_room_events(events: Iterable[EventBase]) -> None: + with Measure(self.clock, "handle_room_events"): + for event in events: + await handle_event(event) + + events_by_room = {} # type: Dict[str, List[EventBase]] + for event in events: + events_by_room.setdefault(event.room_id, []).append(event) + + await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(handle_room_events, evs) + for evs in events_by_room.values() + ], + consumeErrors=True, + ) + ) await self.store.update_federation_out_pos("events", next_token) @@ -391,7 +361,7 @@ class FederationSender(AbstractFederationSender): events_processed_counter.inc(len(events)) event_processing_loop_room_count.labels("federation_sender").inc( - len({event.room_id for event in events}) + len(events_by_room) ) event_processing_loop_counter.labels("federation_sender").inc() @@ -403,53 +373,34 @@ class FederationSender(AbstractFederationSender): finally: self._is_processing = False - async def _distribute_events( - self, - events_and_dests: Iterable[Tuple[EventBase, Collection[str]]], - ) -> None: - """Distribute events to the respective per_destination queues. + async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: + # We loop through all destinations to see whether we already have + # a transaction in progress. If we do, stick it in the pending_pdus + # table and we'll get back to it later. - Also persists last-seen per-room stream_ordering to 'destination_rooms'. + destinations = set(destinations) + destinations.discard(self.server_name) + logger.debug("Sending to: %s", str(destinations)) - Args: - events_and_dests: A list of tuples, which are (event: EventBase, destinations: Collection[str]). - Every event is paired with its intended destinations (in federation). - """ - # Tuples of room_id + destination to their max-seen stream_ordering - room_with_dest_stream_ordering = {} # type: Dict[Tuple[str, str], int] - - # List of events to send to each destination - events_by_dest = {} # type: Dict[str, List[EventBase]] - - # For each event-destinations pair... - for event, destinations in events_and_dests: - - # (we got this from the database, it's filled) - assert event.internal_metadata.stream_ordering - - sent_pdus_destination_dist_total.inc(len(destinations)) - sent_pdus_destination_dist_count.inc() + if not destinations: + return - # ...iterate over those destinations.. - for destination in destinations: - # ...update their stream-ordering... - room_with_dest_stream_ordering[(event.room_id, destination)] = max( - event.internal_metadata.stream_ordering, - room_with_dest_stream_ordering.get((event.room_id, destination), 0), - ) + sent_pdus_destination_dist_total.inc(len(destinations)) + sent_pdus_destination_dist_count.inc() - # ...and add the event to each destination queue. - events_by_dest.setdefault(destination, []).append(event) + assert pdu.internal_metadata.stream_ordering - # Bulk-store destination_rooms stream_ids - await self.store.bulk_store_destination_rooms_entries( - room_with_dest_stream_ordering + # track the fact that we have a PDU for these destinations, + # to allow us to perform catch-up later on if the remote is unreachable + # for a while. + await self.store.store_destination_rooms_entries( + destinations, + pdu.room_id, + pdu.internal_metadata.stream_ordering, ) - for destination, pdus in events_by_dest.items(): - logger.debug("Sending %d pdus to %s", len(pdus), destination) - - self._get_per_destination_queue(destination).send_pdus(pdus) + for destination in destinations: + self._get_per_destination_queue(destination).send_pdu(pdu) async def send_read_receipt(self, receipt: ReadReceipt) -> None: """Send a RR to any other servers in the room @@ -546,48 +497,6 @@ class FederationSender(AbstractFederationSender): for queue in queues: queue.flush_read_receipts_for_room(room_id) - @preserve_fn # the caller should not yield on this - async def send_presence(self, states: List[UserPresenceState]) -> None: - """Send the new presence states to the appropriate destinations. - - This actually queues up the presence states ready for sending and - triggers a background task to process them and send out the transactions. - """ - if not self.hs.config.use_presence: - # No-op if presence is disabled. - return - - # First we queue up the new presence by user ID, so multiple presence - # updates in quick succession are correctly handled. - # We only want to send presence for our own users, so lets always just - # filter here just in case. - self.pending_presence.update( - {state.user_id: state for state in states if self.is_mine_id(state.user_id)} - ) - - # We then handle the new pending presence in batches, first figuring - # out the destinations we need to send each state to and then poking it - # to attempt a new transaction. We linearize this so that we don't - # accidentally mess up the ordering and send multiple presence updates - # in the wrong order - if self._processing_pending_presence: - return - - self._processing_pending_presence = True - try: - while True: - states_map = self.pending_presence - self.pending_presence = {} - - if not states_map: - break - - await self._process_presence_inner(list(states_map.values())) - except Exception: - logger.exception("Error sending presence states to servers") - finally: - self._processing_pending_presence = False - def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: @@ -599,6 +508,10 @@ class FederationSender(AbstractFederationSender): # No-op if presence is disabled. return + # Ensure we only send out presence states for local users. + for state in states: + assert self.is_mine_id(state.user_id) + for destination in destinations: if destination == self.server_name: continue @@ -608,40 +521,6 @@ class FederationSender(AbstractFederationSender): continue self._get_per_destination_queue(destination).send_presence(states) - @measure_func("txnqueue._process_presence") - async def _process_presence_inner(self, states: List[UserPresenceState]) -> None: - """Given a list of states populate self.pending_presence_by_dest and - poke to send a new transaction to each destination - """ - # We pull the presence router here instead of __init__ - # to prevent a dependency cycle: - # - # AuthHandler -> Notifier -> FederationSender - # -> PresenceRouter -> ModuleApi -> AuthHandler - if self._presence_router is None: - self._presence_router = self.hs.get_presence_router() - - assert self._presence_router is not None - - hosts_and_states = await get_interested_remotes( - self.store, - self._presence_router, - states, - self.state, - ) - - for destinations, states in hosts_and_states: - for destination in destinations: - if destination == self.server_name: - continue - - if not self._federation_shard_config.should_handle( - self._instance_name, destination - ): - continue - - self._get_per_destination_queue(destination).send_presence(states) - def build_and_send_edu( self, destination: str, diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 3bb66bce32..3b053ebcfb 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -154,22 +154,19 @@ class PerDestinationQueue: + len(self._pending_edus_keyed) ) - def send_pdus(self, pdus: Iterable[EventBase]) -> None: - """Add PDUs to the queue, and start the transmission loop if necessary + def send_pdu(self, pdu: EventBase) -> None: + """Add a PDU to the queue, and start the transmission loop if necessary Args: - pdus: pdus to send + pdu: pdu to send """ if not self._catching_up or self._last_successful_stream_ordering is None: # only enqueue the PDU if we are not catching up (False) or do not # yet know if we have anything to catch up (None) - self._pending_pdus.extend(pdus) + self._pending_pdus.append(pdu) else: - self._catchup_last_skipped = max( - pdu.internal_metadata.stream_ordering - for pdu in pdus - if pdu.internal_metadata.stream_ordering is not None - ) + assert pdu.internal_metadata.stream_ordering + self._catchup_last_skipped = pdu.internal_metadata.stream_ordering self.attempt_new_transaction() diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 66ce7e8b83..5b927f10b3 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -17,7 +17,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable @@ -39,28 +39,44 @@ class AccountValidityHandler: self.sendmail = self.hs.get_sendmail() self.clock = self.hs.get_clock() - self._account_validity = self.hs.config.account_validity + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) + self._account_validity_renew_by_email_enabled = ( + hs.config.account_validity.account_validity_renew_by_email_enabled + ) + + self._account_validity_period = None + if self._account_validity_enabled: + self._account_validity_period = ( + hs.config.account_validity.account_validity_period + ) if ( - self._account_validity.enabled - and self._account_validity.renew_by_email_enabled + self._account_validity_enabled + and self._account_validity_renew_by_email_enabled ): # Don't do email-specific configuration if renewal by email is disabled. - self._template_html = self.config.account_validity_template_html - self._template_text = self.config.account_validity_template_text + self._template_html = ( + hs.config.account_validity.account_validity_template_html + ) + self._template_text = ( + hs.config.account_validity.account_validity_template_text + ) + account_validity_renew_email_subject = ( + hs.config.account_validity.account_validity_renew_email_subject + ) try: - app_name = self.hs.config.email_app_name + app_name = hs.config.email_app_name - self._subject = self._account_validity.renew_email_subject % { - "app": app_name - } + self._subject = account_validity_renew_email_subject % {"app": app_name} - self._from_string = self.hs.config.email_notif_from % {"app": app_name} + self._from_string = hs.config.email_notif_from % {"app": app_name} except Exception: # If substitution failed, fall back to the bare strings. - self._subject = self._account_validity.renew_email_subject - self._from_string = self.hs.config.email_notif_from + self._subject = account_validity_renew_email_subject + self._from_string = hs.config.email_notif_from self._raw_from = email.utils.parseaddr(self._from_string)[1] @@ -220,50 +236,87 @@ class AccountValidityHandler: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - async def renew_account(self, renewal_token: str) -> bool: + async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. + If it turns out that the token is valid but has already been used, then the + token is considered stale. A token is stale if the 'token_used_ts_ms' db column + is non-null. + Args: renewal_token: Token sent with the renewal request. Returns: - Whether the provided token is valid. + A tuple containing: + * A bool representing whether the token is valid and unused. + * A bool which is `True` if the token is valid, but stale. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. """ try: - user_id = await self.store.get_user_from_renewal_token(renewal_token) + ( + user_id, + current_expiration_ts, + token_used_ts, + ) = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - return False + return False, False, 0 + + # Check whether this token has already been used. + if token_used_ts: + logger.info( + "User '%s' attempted to use previously used token '%s' to renew account", + user_id, + renewal_token, + ) + return False, True, current_expiration_ts logger.debug("Renewing an account for user %s", user_id) - await self.renew_account_for_user(user_id) - return True + # Renew the account. Pass the renewal_token here so that it is not cleared. + # We want to keep the token around in case the user attempts to renew their + # account with the same token twice (clicking the email link twice). + # + # In that case, the token will be accepted, but the account's expiration ts + # will remain unchanged. + new_expiration_ts = await self.renew_account_for_user( + user_id, renewal_token=renewal_token + ) + + return True, False, new_expiration_ts async def renew_account_for_user( self, user_id: str, expiration_ts: Optional[int] = None, email_sent: bool = False, + renewal_token: Optional[str] = None, ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token: Token sent with the renewal request. + user_id: The ID of the user to renew. expiration_ts: New expiration date. Defaults to now + validity period. - email_sen: Whether an email has been sent for this validity period. - Defaults to False. + email_sent: Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. The user's token + will be cleared if this is None. Returns: New expiration date for this account, as a timestamp in milliseconds since epoch. """ + now = self.clock.time_msec() if expiration_ts is None: - expiration_ts = self.clock.time_msec() + self._account_validity.period + expiration_ts = now + self._account_validity_period await self.store.set_account_validity_for_user( - user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent + user_id=user_id, + expiration_ts=expiration_ts, + email_sent=email_sent, + renewal_token=renewal_token, + token_used_ts=now, ) return expiration_ts diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index d7bc4e23ed..177310f0be 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union from prometheus_client import Counter @@ -33,7 +33,7 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID from synapse.util.metrics import Measure if TYPE_CHECKING: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b8a37b6477..36f2450e2e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1248,7 +1248,7 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - for token, token_id, device_id in tokens_and_devices: + for token, _, device_id in tokens_and_devices: await provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token ) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas.py index 7346ccfe93..7346ccfe93 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas.py diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 3f6f9f7f3d..45d2404dde 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -49,7 +49,9 @@ class DeactivateAccountHandler(BaseHandler): if hs.config.run_background_tasks: hs.get_reactor().callWhenRunning(self._start_user_parting) - self._account_validity_enabled = hs.config.account_validity.enabled + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) async def deactivate_account( self, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d75edb184b..95bdc5902a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -28,7 +28,6 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( - Collection, JsonDict, StreamToken, UserID, @@ -156,8 +155,7 @@ class DeviceWorkerHandler(BaseHandler): # The user may have left the room # TODO: Check if they actually did or if we were just invited. if room_id not in room_ids: - for key, event_id in current_state_ids.items(): - etype, state_key = key + for etype, state_key in current_state_ids.keys(): if etype != EventTypes.Member: continue possibly_left.add(state_key) @@ -179,8 +177,7 @@ class DeviceWorkerHandler(BaseHandler): log_kv( {"event": "encountered empty previous state", "room_id": room_id} ) - for key, event_id in current_state_ids.items(): - etype, state_key = key + for etype, state_key in current_state_ids.keys(): if etype != EventTypes.Member: continue possibly_changed.add(state_key) @@ -198,8 +195,7 @@ class DeviceWorkerHandler(BaseHandler): for state_dict in prev_state_ids.values(): member_event = state_dict.get((EventTypes.Member, user_id), None) if not member_event or member_event != current_member_id: - for key, event_id in current_state_ids.items(): - etype, state_key = key + for etype, state_key in current_state_ids.keys(): if etype != EventTypes.Member: continue possibly_changed.add(state_key) @@ -714,7 +710,7 @@ class DeviceListUpdater: # This can happen since we batch updates return - for device_id, stream_id, prev_ids, content in pending_updates: + for device_id, stream_id, prev_ids, _ in pending_updates: logger.debug( "Handling update %r/%r, ID: %r, prev: %r ", user_id, @@ -740,7 +736,7 @@ class DeviceListUpdater: else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) - for device_id, stream_id, prev_ids, content in pending_updates: + for device_id, stream_id, _, content in pending_updates: await self.store.update_remote_device_list_cache_entry( user_id, device_id, content, stream_id ) @@ -929,6 +925,10 @@ class DeviceListUpdater: else: cached_devices = await self.store.get_cached_devices_for_user(user_id) if cached_devices == {d["device_id"]: d for d in devices}: + logging.info( + "Skipping device list resync for %s, as our cache matches already", + user_id, + ) devices = [] ignore_devices = True @@ -944,6 +944,9 @@ class DeviceListUpdater: await self.store.update_remote_device_list_cache( user_id, devices, stream_id ) + # mark the cache as valid, whether or not we actually processed any device + # list updates. + await self.store.mark_remote_user_device_cache_as_valid(user_id) device_ids = [device["device_id"] for device in devices] # Handle cross-signing keys. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 90932316f3..de1b14cde3 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -14,7 +14,7 @@ import logging import string -from typing import Iterable, List, Optional +from typing import TYPE_CHECKING, Iterable, List, Optional from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( @@ -27,15 +27,19 @@ from synapse.api.errors import ( SynapseError, ) from synapse.appservice import ApplicationService -from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id +from synapse.storage.databases.main.directory import RoomAliasMapping +from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class DirectoryHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.state = hs.get_state_handler() @@ -60,7 +64,7 @@ class DirectoryHandler(BaseHandler): room_id: str, servers: Optional[Iterable[str]] = None, creator: Optional[str] = None, - ): + ) -> None: # general association creation for both human users and app services for wchar in string.whitespace: @@ -104,8 +108,9 @@ class DirectoryHandler(BaseHandler): """ user_id = requester.user.to_string() + room_alias_str = room_alias.to_string() - if len(room_alias.to_string()) > MAX_ALIAS_LENGTH: + if len(room_alias_str) > MAX_ALIAS_LENGTH: raise SynapseError( 400, "Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH, @@ -114,7 +119,7 @@ class DirectoryHandler(BaseHandler): service = requester.app_service if service: - if not service.is_interested_in_alias(room_alias.to_string()): + if not service.is_interested_in_alias(room_alias_str): raise SynapseError( 400, "This application service has not reserved this kind of alias.", @@ -138,7 +143,7 @@ class DirectoryHandler(BaseHandler): raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( - user_id, room_id, room_alias.to_string() + user_id, room_id, room_alias_str ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages @@ -211,7 +216,7 @@ class DirectoryHandler(BaseHandler): async def delete_appservice_association( self, service: ApplicationService, room_alias: RoomAlias - ): + ) -> None: if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( 400, @@ -220,7 +225,7 @@ class DirectoryHandler(BaseHandler): ) await self._delete_association(room_alias) - async def _delete_association(self, room_alias: RoomAlias): + async def _delete_association(self, room_alias: RoomAlias) -> str: if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") @@ -228,17 +233,19 @@ class DirectoryHandler(BaseHandler): return room_id - async def get_association(self, room_alias: RoomAlias): + async def get_association(self, room_alias: RoomAlias) -> JsonDict: room_id = None if self.hs.is_mine(room_alias): - result = await self.get_association_from_room_alias(room_alias) + result = await self.get_association_from_room_alias( + room_alias + ) # type: Optional[RoomAliasMapping] if result: room_id = result.room_id servers = result.servers else: try: - result = await self.federation.make_query( + fed_result = await self.federation.make_query( destination=room_alias.domain, query_type="directory", args={"room_alias": room_alias.to_string()}, @@ -248,13 +255,13 @@ class DirectoryHandler(BaseHandler): except CodeMessageException as e: logging.warning("Error retrieving alias") if e.code == 404: - result = None + fed_result = None else: raise - if result and "room_id" in result and "servers" in result: - room_id = result["room_id"] - servers = result["servers"] + if fed_result and "room_id" in fed_result and "servers" in fed_result: + room_id = fed_result["room_id"] + servers = fed_result["servers"] if not room_id: raise SynapseError( @@ -275,7 +282,7 @@ class DirectoryHandler(BaseHandler): return {"room_id": room_id, "servers": servers} - async def on_directory_query(self, args): + async def on_directory_query(self, args: JsonDict) -> JsonDict: room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room Alias is not hosted on this homeserver") @@ -293,7 +300,7 @@ class DirectoryHandler(BaseHandler): async def _update_canonical_alias( self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias - ): + ) -> None: """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. @@ -344,7 +351,9 @@ class DirectoryHandler(BaseHandler): ratelimit=False, ) - async def get_association_from_room_alias(self, room_alias: RoomAlias): + async def get_association_from_room_alias( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: result = await self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists @@ -372,7 +381,7 @@ class DirectoryHandler(BaseHandler): # either no interested services, or no service with an exclusive lock return True - async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool: """Determine whether a user can delete an alias. One of the following must be true: @@ -394,14 +403,13 @@ class DirectoryHandler(BaseHandler): if not room_id: return False - res = await self.auth.check_can_change_room_list( + return await self.auth.check_can_change_room_list( room_id, UserID.from_string(user_id) ) - return res async def edit_published_room_list( self, requester: Requester, room_id: str, visibility: str - ): + ) -> None: """Edit the entry of the room in the published room list. requester @@ -469,7 +477,7 @@ class DirectoryHandler(BaseHandler): async def edit_published_appservice_room_list( self, appservice_id: str, network_id: str, room_id: str, visibility: str - ): + ) -> None: """Add or remove a room from the appservice/network specific public room list. @@ -499,5 +507,4 @@ class DirectoryHandler(BaseHandler): room_id, requester.user.to_string() ) - aliases = await self.store.get_aliases_for_room(room_id) - return aliases + return await self.store.get_aliases_for_room(room_id) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py new file mode 100644 index 0000000000..eff639f407 --- /dev/null +++ b/synapse/handlers/event_auth.py @@ -0,0 +1,86 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from synapse.api.constants import EventTypes, JoinRules +from synapse.api.room_versions import RoomVersion +from synapse.types import StateMap + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class EventAuthHandler: + """ + This class contains methods for authenticating events added to room graphs. + """ + + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + + async def can_join_without_invite( + self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str + ) -> bool: + """ + Check whether a user can join a room without an invite. + + When joining a room with restricted joined rules (as defined in MSC3083), + the membership of spaces must be checked during join. + + Args: + state_ids: The state of the room as it currently is. + room_version: The room version of the room being joined. + user_id: The user joining the room. + + Returns: + True if the user can join the room, false otherwise. + """ + # This only applies to room versions which support the new join rule. + if not room_version.msc3083_join_rules: + return True + + # If there's no join rule, then it defaults to invite (so this doesn't apply). + join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) + if not join_rules_event_id: + return True + + # If the join rule is not restricted, this doesn't apply. + join_rules_event = await self._store.get_event(join_rules_event_id) + if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED: + return True + + # If allowed is of the wrong form, then only allow invited users. + allowed_spaces = join_rules_event.content.get("allow", []) + if not isinstance(allowed_spaces, list): + return False + + # Get the list of joined rooms and see if there's an overlap. + joined_rooms = await self._store.get_rooms_for_user(user_id) + + # Pull out the other room IDs, invalid data gets filtered. + for space in allowed_spaces: + if not isinstance(space, dict): + continue + + space_id = space.get("space") + if not isinstance(space_id, str): + continue + + # The user was joined to one of the spaces specified, they can join + # this room! + if space_id in joined_rooms: + return True + + # The user was not in any of the required spaces. + return False diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4b3730aa3b..9d867aaf4d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -146,6 +146,7 @@ class FederationHandler(BaseHandler): self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._event_auth_handler = hs.get_event_auth_handler() self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config @@ -1673,8 +1674,40 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin + # Calculate the event context. context = await self.state_handler.compute_event_context(event) - context = await self._auth_and_persist_event(origin, event, context) + + # Get the state before the new event. + prev_state_ids = await context.get_prev_state_ids() + + # Check if the user is already in the room or invited to the room. + user_id = event.state_key + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + newly_joined = True + user_is_invited = False + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + user_is_invited = prev_member_event.membership == Membership.INVITE + + # If the member is not already in the room, and not invited, check if + # they should be allowed access via membership in a space. + if ( + newly_joined + and not user_is_invited + and not await self._event_auth_handler.can_join_without_invite( + prev_state_ids, + event.room_version, + user_id, + ) + ): + raise AuthError( + 403, + "You do not belong to any of the required spaces to join this room.", + ) + + # Persist the event. + await self._auth_and_persist_event(origin, event, context) logger.debug( "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s", @@ -1682,8 +1715,6 @@ class FederationHandler(BaseHandler): event.signatures, ) - prev_state_ids = await context.get_prev_state_ids() - state_ids = list(prev_state_ids.values()) auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) @@ -2006,7 +2037,7 @@ class FederationHandler(BaseHandler): state: Optional[Iterable[EventBase]] = None, auth_events: Optional[MutableStateMap[EventBase]] = None, backfilled: bool = False, - ) -> EventContext: + ) -> None: """ Process an event by performing auth checks and then persisting to the database. @@ -2028,9 +2059,6 @@ class FederationHandler(BaseHandler): event is an outlier), may be the auth events claimed by the remote server. backfilled: True if the event was backfilled. - - Returns: - The event context. """ context = await self._check_event_auth( origin, @@ -2060,8 +2088,6 @@ class FederationHandler(BaseHandler): ) raise - return context - async def _auth_and_persist_events( self, origin: str, @@ -2956,7 +2982,7 @@ class FederationHandler(BaseHandler): try: # for each sig on the third_party_invite block of the actual invite for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): + for key_name in signature_block.keys(): if not key_name.startswith("ed25519:"): continue diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 87a8b89237..33d16fbf9c 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -15,10 +15,9 @@ # limitations under the License. """Utilities for interacting with Identity Servers""" - import logging import urllib.parse -from typing import Awaitable, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple from synapse.api.errors import ( CodeMessageException, @@ -34,17 +33,24 @@ from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 -from synapse.util.stringutils import assert_valid_client_secret, random_string +from synapse.util.stringutils import ( + assert_valid_client_secret, + random_string, + valid_id_server_location, +) from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) id_server_scheme = "https://" class IdentityHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) # An HTTP client for contacting trusted URLs. @@ -77,7 +83,7 @@ class IdentityHandler(BaseHandler): request: SynapseRequest, medium: str, address: str, - ): + ) -> None: """Used to ratelimit requests to `/requestToken` by IP and address. Args: @@ -172,6 +178,11 @@ class IdentityHandler(BaseHandler): server with, if necessary. Required if use_v2 is true use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True + Raises: + SynapseError: On any of the following conditions + - the supplied id_server is not a valid identity server name + - we failed to contact the supplied identity server + Returns: The response from the identity server """ @@ -181,6 +192,12 @@ class IdentityHandler(BaseHandler): if id_access_token is None: use_v2 = False + if not valid_id_server_location(id_server): + raise SynapseError( + 400, + "id_server must be a valid hostname with optional port and path components", + ) + # Decide which API endpoint URLs to use headers = {} bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} @@ -269,12 +286,21 @@ class IdentityHandler(BaseHandler): id_server: Identity server to unbind from Raises: - SynapseError: If we failed to contact the identity server + SynapseError: On any of the following conditions + - the supplied id_server is not a valid identity server name + - we failed to contact the supplied identity server Returns: True on success, otherwise False if the identity server doesn't support unbinding """ + + if not valid_id_server_location(id_server): + raise SynapseError( + 400, + "id_server must be a valid hostname with optional port and path components", + ) + url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ec8eb21674..49f8aa25ea 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple from canonicaljson import encode_canonical_json @@ -66,7 +66,7 @@ logger = logging.getLogger(__name__) class MessageHandler: """Contains some read only APIs to get state about a room""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -91,7 +91,7 @@ class MessageHandler: room_id: str, event_type: str, state_key: str, - ) -> dict: + ) -> Optional[EventBase]: """Get data from a room. Args: @@ -115,6 +115,10 @@ class MessageHandler: data = await self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) + # If the membership is not JOIN, then the event ID should exist. + assert ( + membership_event_id is not None + ), "check_user_in_room_or_world_readable returned invalid data" room_state = await self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) @@ -186,10 +190,12 @@ class MessageHandler: event = last_events[0] if visible_events: - room_state = await self.state_store.get_state_for_events( + room_state_events = await self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) - room_state = room_state[event.event_id] + room_state = room_state_events[ + event.event_id + ] # type: Mapping[Any, EventBase] else: raise AuthError( 403, @@ -210,10 +216,14 @@ class MessageHandler: ) room_state = await self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = await self.state_store.get_state_for_events( + # If the membership is not JOIN, then the event ID should exist. + assert ( + membership_event_id is not None + ), "check_user_in_room_or_world_readable returned invalid data" + room_state_events = await self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) - room_state = room_state[membership_event_id] + room_state = room_state_events[membership_event_id] now = self.clock.time_msec() events = await self._event_serializer.serialize_events( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc.py index b156196a70..ee6e41c0e4 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc.py @@ -15,7 +15,7 @@ import inspect import logging from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union -from urllib.parse import urlencode +from urllib.parse import urlencode, urlparse import attr import pymacaroons @@ -37,10 +37,7 @@ from twisted.web.client import readBody from twisted.web.http_headers import Headers from synapse.config import ConfigError -from synapse.config.oidc_config import ( - OidcProviderClientSecretJwtKey, - OidcProviderConfig, -) +from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable @@ -71,8 +68,8 @@ logger = logging.getLogger(__name__) # # Here we have the names of the cookies, and the options we use to set them. _SESSION_COOKIES = [ - (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"), - (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"), + (b"oidc_session", b"HttpOnly; Secure; SameSite=None"), + (b"oidc_session_no_samesite", b"HttpOnly"), ] #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and @@ -282,6 +279,13 @@ class OidcProvider: self._config = provider self._callback_url = hs.config.oidc_callback_url # type: str + # Calculate the prefix for OIDC callback paths based on the public_baseurl. + # We'll insert this into the Path= parameter of any session cookies we set. + public_baseurl_path = urlparse(hs.config.server.public_baseurl).path + self._callback_path_prefix = ( + public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc" + ) + self._oidc_attribute_requirements = provider.attribute_requirements self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method @@ -782,8 +786,13 @@ class OidcProvider: for cookie_name, options in _SESSION_COOKIES: request.cookies.append( - b"%s=%s; Max-Age=3600; %s" - % (cookie_name, cookie.encode("utf-8"), options) + b"%s=%s; Max-Age=3600; Path=%s; %s" + % ( + cookie_name, + cookie.encode("utf-8"), + self._callback_path_prefix, + options, + ) ) metadata = await self.load_metadata() @@ -960,6 +969,11 @@ class OidcProvider: # and attempt to match it. attributes = await oidc_response_to_user_attributes(failures=0) + if attributes.localpart is None: + # If no localpart is returned then we will generate one, so + # there is no need to search for existing users. + return None + user_id = UserID(attributes.localpart, self._server_name).to_string() users = await self._store.get_users_by_id_case_insensitive(user_id) if users: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4c4d640ce2..d1d8cfb9f4 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -24,9 +24,12 @@ The methods that define policy are: import abc import contextlib import logging +from bisect import bisect from contextlib import contextmanager from typing import ( TYPE_CHECKING, + Callable, + Collection, Dict, FrozenSet, Iterable, @@ -53,10 +56,11 @@ from synapse.replication.http.presence import ( ReplicationBumpPresenceActiveTime, ReplicationPresenceSetState, ) +from synapse.replication.http.streams import ReplicationGetStreamUpdates from synapse.replication.tcp.commands import ClearUserSyncsCommand -from synapse.state import StateHandler +from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.storage.databases.main import DataStore -from synapse.types import Collection, JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.metrics import Measure @@ -118,11 +122,21 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class BasePresenceHandler(abc.ABC): - """Parts of the PresenceHandler that are shared between workers and master""" + """Parts of the PresenceHandler that are shared between workers and presence + writer""" def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() + self.presence_router = hs.get_presence_router() + self.state = hs.get_state_handler() + self.is_mine_id = hs.is_mine_id + + self._federation = None + if hs.should_send_federation(): + self._federation = hs.get_federation_sender() + + self._federation_queue = PresenceFederationQueue(hs, self) self._busy_presence_enabled = hs.config.experimental.msc3026_enabled @@ -219,23 +233,23 @@ class BasePresenceHandler(abc.ABC): """ async def update_external_syncs_row( - self, process_id, user_id, is_syncing, sync_time_msec - ): + self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + ) -> None: """Update the syncing users for an external process as a delta. This is a no-op when presence is handled by a different worker. Args: - process_id (str): An identifier for the process the users are + process_id: An identifier for the process the users are syncing against. This allows synapse to process updates as user start and stop syncing against a given process. - user_id (str): The user who has started or stopped syncing - is_syncing (bool): Whether or not the user is now syncing - sync_time_msec(int): Time in ms when the user was last syncing + user_id: The user who has started or stopped syncing + is_syncing: Whether or not the user is now syncing + sync_time_msec: Time in ms when the user was last syncing """ pass - async def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id: str) -> None: """Marks all users that had been marked as syncing by a given process as offline. @@ -245,9 +259,42 @@ class BasePresenceHandler(abc.ABC): """ pass - async def process_replication_rows(self, token, rows): - """Process presence stream rows received over replication.""" - pass + async def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: list + ): + """Process streams received over replication.""" + await self._federation_queue.process_replication_rows( + stream_name, instance_name, token, rows + ) + + def get_federation_queue(self) -> "PresenceFederationQueue": + """Get the presence federation queue.""" + return self._federation_queue + + async def maybe_send_presence_to_interested_destinations( + self, states: List[UserPresenceState] + ): + """If this instance is a federation sender, send the states to all + destinations that are interested. Filters out any states for remote + users. + """ + + if not self._federation: + return + + states = [s for s in states if self.is_mine_id(s.user_id)] + + if not states: + return + + hosts_and_states = await get_interested_remotes( + self.store, + self.presence_router, + states, + ) + + for destinations, states in hosts_and_states: + self._federation.send_presence_to_destinations(states, destinations) class _NullContextManager(ContextManager[None]): @@ -258,14 +305,20 @@ class _NullContextManager(ContextManager[None]): class WorkerPresenceHandler(BasePresenceHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self.is_mine_id = hs.is_mine_id - self.presence_router = hs.get_presence_router() + self._presence_writer_instance = hs.config.worker.writers.presence[0] + self._presence_enabled = hs.config.use_presence + # Route presence EDUs to the right worker + hs.get_federation_registry().register_instances_for_edu( + "m.presence", + hs.config.worker.writers.presence, + ) + # The number of ongoing syncs on this process, by user id. # Empty if _presence_enabled is false. self._user_to_num_current_syncs = {} # type: Dict[str, int] @@ -273,9 +326,9 @@ class WorkerPresenceHandler(BasePresenceHandler): self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - # user_id -> last_sync_ms. Lists the users that have stopped syncing - # but we haven't notified the master of that yet - self.users_going_offline = {} + # user_id -> last_sync_ms. Lists the users that have stopped syncing but + # we haven't notified the presence writer of that yet + self.users_going_offline = {} # type: Dict[str, int] self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -294,44 +347,39 @@ class WorkerPresenceHandler(BasePresenceHandler): self._on_shutdown, ) - def _on_shutdown(self): + def _on_shutdown(self) -> None: if self._presence_enabled: self.hs.get_tcp_replication().send_command( ClearUserSyncsCommand(self.instance_id) ) - def send_user_sync(self, user_id, is_syncing, last_sync_ms): + def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None: if self._presence_enabled: self.hs.get_tcp_replication().send_user_sync( self.instance_id, user_id, is_syncing, last_sync_ms ) - def mark_as_coming_online(self, user_id): - """A user has started syncing. Send a UserSync to the master, unless they - had recently stopped syncing. - - Args: - user_id (str) + def mark_as_coming_online(self, user_id: str) -> None: + """A user has started syncing. Send a UserSync to the presence writer, + unless they had recently stopped syncing. """ going_offline = self.users_going_offline.pop(user_id, None) if not going_offline: - # Safe to skip because we haven't yet told the master they were offline + # Safe to skip because we haven't yet told the presence writer they + # were offline self.send_user_sync(user_id, True, self.clock.time_msec()) - def mark_as_going_offline(self, user_id): - """A user has stopped syncing. We wait before notifying the master as - its likely they'll come back soon. This allows us to avoid sending - a stopped syncing immediately followed by a started syncing notification - to the master - - Args: - user_id (str) + def mark_as_going_offline(self, user_id: str) -> None: + """A user has stopped syncing. We wait before notifying the presence + writer as its likely they'll come back soon. This allows us to avoid + sending a stopped syncing immediately followed by a started syncing + notification to the presence writer """ self.users_going_offline[user_id] = self.clock.time_msec() - def send_stop_syncing(self): - """Check if there are any users who have stopped syncing a while ago - and haven't come back yet. If there are poke the master about them. + def send_stop_syncing(self) -> None: + """Check if there are any users who have stopped syncing a while ago and + haven't come back yet. If there are poke the presence writer about them. """ now = self.clock.time_msec() for user_id, last_sync_ms in list(self.users_going_offline.items()): @@ -377,7 +425,9 @@ class WorkerPresenceHandler(BasePresenceHandler): return _user_syncing() - async def notify_from_replication(self, states, stream_id): + async def notify_from_replication( + self, states: List[UserPresenceState], stream_id: int + ) -> None: parties = await get_interested_parties(self.store, self.presence_router, states) room_ids_to_states, users_to_states = parties @@ -388,7 +438,17 @@ class WorkerPresenceHandler(BasePresenceHandler): users=users_to_states.keys(), ) - async def process_replication_rows(self, token, rows): + # If this is a federation sender, notify about presence updates. + await self.maybe_send_presence_to_interested_destinations(states) + + async def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: list + ): + await super().process_replication_rows(stream_name, instance_name, token, rows) + + if stream_name != PresenceStream.NAME: + return + states = [ UserPresenceState( row.user_id, @@ -415,7 +475,12 @@ class WorkerPresenceHandler(BasePresenceHandler): if count > 0 ] - async def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state( + self, + target_user: UserID, + state: JsonDict, + ignore_status_msg: bool = False, + ) -> None: """Set the presence state of the user.""" presence = state["presence"] @@ -437,12 +502,15 @@ class WorkerPresenceHandler(BasePresenceHandler): if not self.hs.config.use_presence: return - # Proxy request to master + # Proxy request to instance that writes presence await self._set_state_client( - user_id=user_id, state=state, ignore_status_msg=ignore_status_msg + instance_name=self._presence_writer_instance, + user_id=user_id, + state=state, + ignore_status_msg=ignore_status_msg, ) - async def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user: UserID) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -450,22 +518,20 @@ class WorkerPresenceHandler(BasePresenceHandler): if not self.hs.config.use_presence: return - # Proxy request to master + # Proxy request to instance that writes presence user_id = user.to_string() - await self._bump_active_client(user_id=user_id) + await self._bump_active_client( + instance_name=self._presence_writer_instance, user_id=user_id + ) class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() - self.federation = hs.get_federation_sender() - self.state = hs.get_state_handler() - self.presence_router = hs.get_presence_router() self._presence_enabled = hs.config.use_presence federation_registry = hs.get_federation_registry() @@ -528,8 +594,8 @@ class PresenceHandler(BasePresenceHandler): # we assume that all the sync requests on that process have stopped. # Stored as a dict from process_id to set of user_id, and a dict of # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs = {} # type: Dict[int, Set[str]] - self.external_process_last_updated_ms = {} # type: Dict[int, int] + self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]] + self.external_process_last_updated_ms = {} # type: Dict[str, int] self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -569,7 +635,7 @@ class PresenceHandler(BasePresenceHandler): self._event_pos = self.store.get_current_events_token() self._event_processing = False - async def _on_shutdown(self): + async def _on_shutdown(self) -> None: """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal timers. This allows changes to persist across startup without having to @@ -598,7 +664,7 @@ class PresenceHandler(BasePresenceHandler): ) logger.info("Finished _on_shutdown") - async def _persist_unpersisted_changes(self): + async def _persist_unpersisted_changes(self) -> None: """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. """ @@ -672,6 +738,13 @@ class PresenceHandler(BasePresenceHandler): self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) + # Check if we need to resend any presence states to remote hosts. We + # only do this for states that haven't been updated in a while to + # ensure that the remote host doesn't time the presence state out. + # + # Note that since these are states that have *not* been updated, + # they won't get sent down the normal presence replication stream, + # and so we have to explicitly send them via the federation stream. to_federation_ping = { user_id: state for user_id, state in to_federation_ping.items() @@ -680,9 +753,18 @@ class PresenceHandler(BasePresenceHandler): if to_federation_ping: federation_presence_out_counter.inc(len(to_federation_ping)) - self._push_to_remotes(to_federation_ping.values()) + hosts_and_states = await get_interested_remotes( + self.store, + self.presence_router, + list(to_federation_ping.values()), + ) + + for destinations, states in hosts_and_states: + self._federation_queue.send_presence_to_destinations( + states, destinations + ) - async def _handle_timeouts(self): + async def _handle_timeouts(self) -> None: """Checks the presence of users that have timed out and updates as appropriate. """ @@ -734,7 +816,7 @@ class PresenceHandler(BasePresenceHandler): return await self._update_states(changes) - async def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user: UserID) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -831,17 +913,17 @@ class PresenceHandler(BasePresenceHandler): return [] async def update_external_syncs_row( - self, process_id, user_id, is_syncing, sync_time_msec - ): + self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + ) -> None: """Update the syncing users for an external process as a delta. Args: - process_id (str): An identifier for the process the users are + process_id: An identifier for the process the users are syncing against. This allows synapse to process updates as user start and stop syncing against a given process. - user_id (str): The user who has started or stopped syncing - is_syncing (bool): Whether or not the user is now syncing - sync_time_msec(int): Time in ms when the user was last syncing + user_id: The user who has started or stopped syncing + is_syncing: Whether or not the user is now syncing + sync_time_msec: Time in ms when the user was last syncing """ with (await self.external_sync_linearizer.queue(process_id)): prev_state = await self.current_state_for_user(user_id) @@ -878,7 +960,7 @@ class PresenceHandler(BasePresenceHandler): self.external_process_last_updated_ms[process_id] = self.clock.time_msec() - async def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id: str) -> None: """Marks all users that had been marked as syncing by a given process as offline. @@ -899,12 +981,12 @@ class PresenceHandler(BasePresenceHandler): ) self.external_process_last_updated_ms.pop(process_id, None) - async def current_state_for_user(self, user_id): + async def current_state_for_user(self, user_id: str) -> UserPresenceState: """Get the current presence state for a user.""" res = await self.current_state_for_users([user_id]) return res[user_id] - async def _persist_and_notify(self, states): + async def _persist_and_notify(self, states: List[UserPresenceState]) -> None: """Persist states in the database, poke the notifier and send to interested remote servers """ @@ -920,17 +1002,12 @@ class PresenceHandler(BasePresenceHandler): users=[UserID.from_string(u) for u in users_to_states], ) - self._push_to_remotes(states) - - def _push_to_remotes(self, states): - """Sends state updates to remote servers. - - Args: - states (list(UserPresenceState)) - """ - self.federation.send_presence(states) + # We only want to poke the local federation sender, if any, as other + # workers will receive the presence updates via the presence replication + # stream (which is updated by `store.update_presence`). + await self.maybe_send_presence_to_interested_destinations(states) - async def incoming_presence(self, origin, content): + async def incoming_presence(self, origin: str, content: JsonDict) -> None: """Called when we receive a `m.presence` EDU from a remote server.""" if not self._presence_enabled: return @@ -980,7 +1057,9 @@ class PresenceHandler(BasePresenceHandler): federation_presence_counter.inc(len(updates)) await self._update_states(updates) - async def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state( + self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False + ) -> None: """Set the presence state of the user.""" status_msg = state.get("status_msg", None) presence = state["presence"] @@ -1014,7 +1093,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states([prev_state.copy_and_replace(**new_fields)]) - async def is_visible(self, observed_user, observer_user): + async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: """Returns whether a user can see another user's presence.""" observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() @@ -1069,7 +1148,7 @@ class PresenceHandler(BasePresenceHandler): ) return rows - def notify_new_event(self): + def notify_new_event(self) -> None: """Called when new events have happened. Handles users and servers joining rooms and require being sent presence. """ @@ -1088,7 +1167,7 @@ class PresenceHandler(BasePresenceHandler): run_as_background_process("presence.notify_new_event", _process_presence) - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "presence_delta"): @@ -1113,7 +1192,7 @@ class PresenceHandler(BasePresenceHandler): max_pos ) - async def _handle_state_delta(self, deltas): + async def _handle_state_delta(self, deltas: List[JsonDict]) -> None: """Process current state deltas to find new joins that need to be handled. """ @@ -1166,7 +1245,7 @@ class PresenceHandler(BasePresenceHandler): # Send out user presence updates for each destination for destination, user_state_set in presence_destinations.items(): - self.federation.send_presence_to_destinations( + self._federation_queue.send_presence_to_destinations( destinations=[destination], states=user_state_set ) @@ -1236,7 +1315,7 @@ class PresenceHandler(BasePresenceHandler): return [remote_host], states -def should_notify(old_state, new_state): +def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool: """Decides if a presence state change should be sent to interested parties.""" if old_state == new_state: return False @@ -1272,7 +1351,9 @@ def should_notify(old_state, new_state): return False -def format_user_presence_state(state, now, include_user_id=True): +def format_user_presence_state( + state: UserPresenceState, now: int, include_user_id: bool = True +) -> JsonDict: """Convert UserPresenceState to a format that can be sent down to clients and to other servers. @@ -1305,16 +1386,15 @@ class PresenceEventSource: self.get_presence_router = hs.get_presence_router self.clock = hs.get_clock() self.store = hs.get_datastore() - self.state = hs.get_state_handler() @log_function async def get_new_events( self, - user, - from_key, - room_ids=None, - include_offline=True, - explicit_room_id=None, + user: UserID, + from_key: Optional[int], + room_ids: Optional[List[str]] = None, + include_offline: bool = True, + explicit_room_id: Optional[str] = None, **kwargs, ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: @@ -1524,7 +1604,7 @@ class PresenceEventSource: if update.state != PresenceState.OFFLINE ] - def get_current_key(self): + def get_current_key(self) -> int: return self.store.get_current_presence_token() @cached(num_args=2, cache_context=True) @@ -1584,15 +1664,20 @@ class PresenceEventSource: return users_interested_in -def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): +def handle_timeouts( + user_states: List[UserPresenceState], + is_mine_fn: Callable[[str], bool], + syncing_user_ids: Set[str], + now: int, +) -> List[UserPresenceState]: """Checks the presence of users that have timed out and updates as appropriate. Args: - user_states(list): List of UserPresenceState's to check. - is_mine_fn (fn): Function that returns if a user_id is ours - syncing_user_ids (set): Set of user_ids with active syncs. - now (int): Current time in ms. + user_states: List of UserPresenceState's to check. + is_mine_fn: Function that returns if a user_id is ours + syncing_user_ids: Set of user_ids with active syncs. + now: Current time in ms. Returns: List of UserPresenceState updates @@ -1609,14 +1694,16 @@ def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): return list(changes.values()) -def handle_timeout(state, is_mine, syncing_user_ids, now): +def handle_timeout( + state: UserPresenceState, is_mine: bool, syncing_user_ids: Set[str], now: int +) -> Optional[UserPresenceState]: """Checks the presence of the user to see if any of the timers have elapsed Args: - state (UserPresenceState) - is_mine (bool): Whether the user is ours - syncing_user_ids (set): Set of user_ids with active syncs. - now (int): Current time in ms. + state + is_mine: Whether the user is ours + syncing_user_ids: Set of user_ids with active syncs. + now: Current time in ms. Returns: A UserPresenceState update or None if no update. @@ -1668,23 +1755,29 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): return state if changed else None -def handle_update(prev_state, new_state, is_mine, wheel_timer, now): +def handle_update( + prev_state: UserPresenceState, + new_state: UserPresenceState, + is_mine: bool, + wheel_timer: WheelTimer, + now: int, +) -> Tuple[UserPresenceState, bool, bool]: """Given a presence update: 1. Add any appropriate timers. 2. Check if we should notify anyone. Args: - prev_state (UserPresenceState) - new_state (UserPresenceState) - is_mine (bool): Whether the user is ours - wheel_timer (WheelTimer) - now (int): Time now in ms + prev_state + new_state + is_mine: Whether the user is ours + wheel_timer + now: Time now in ms Returns: 3-tuple: `(new_state, persist_and_notify, federation_ping)` where: - new_state: is the state to actually persist - - persist_and_notify (bool): whether to persist and notify people - - federation_ping (bool): whether we should send a ping over federation + - persist_and_notify: whether to persist and notify people + - federation_ping: whether we should send a ping over federation """ user_id = new_state.user_id @@ -1779,7 +1872,6 @@ async def get_interested_remotes( store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState], - state_handler: StateHandler, ) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -1790,7 +1882,6 @@ async def get_interested_remotes( store: The homeserver's data store. presence_router: A module for augmenting the destinations for presence updates. states: A list of incoming user presence updates. - state_handler: Returns: A list of 2-tuples of destinations and states, where for @@ -1807,7 +1898,8 @@ async def get_interested_remotes( ) for room_id, states in room_ids_to_states.items(): - hosts = await state_handler.get_current_hosts_in_room(room_id) + user_ids = await store.get_users_in_room(room_id) + hosts = {get_domain_from_id(user_id) for user_id in user_ids} hosts_and_states.append((hosts, states)) for user_id, states in users_to_states.items(): @@ -1815,3 +1907,220 @@ async def get_interested_remotes( hosts_and_states.append(([host], states)) return hosts_and_states + + +class PresenceFederationQueue: + """Handles sending ad hoc presence updates over federation, which are *not* + due to state updates (that get handled via the presence stream), e.g. + federation pings and sending existing present states to newly joined hosts. + + Only the last N minutes will be queued, so if a federation sender instance + is down for longer then some updates will be dropped. This is OK as presence + is ephemeral, and so it will self correct eventually. + + On workers the class tracks the last received position of the stream from + replication, and handles querying for missed updates over HTTP replication, + c.f. `get_current_token` and `get_replication_rows`. + """ + + # How long to keep entries in the queue for. Workers that are down for + # longer than this duration will miss out on older updates. + _KEEP_ITEMS_IN_QUEUE_FOR_MS = 5 * 60 * 1000 + + # How often to check if we can expire entries from the queue. + _CLEAR_ITEMS_EVERY_MS = 60 * 1000 + + def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler): + self._clock = hs.get_clock() + self._notifier = hs.get_notifier() + self._instance_name = hs.get_instance_name() + self._presence_handler = presence_handler + self._repl_client = ReplicationGetStreamUpdates.make_client(hs) + + # Should we keep a queue of recent presence updates? We only bother if + # another process may be handling federation sending. + self._queue_presence_updates = True + + # Whether this instance is a presence writer. + self._presence_writer = self._instance_name in hs.config.worker.writers.presence + + # The FederationSender instance, if this process sends federation traffic directly. + self._federation = None + + if hs.should_send_federation(): + self._federation = hs.get_federation_sender() + + # We don't bother queuing up presence states if only this instance + # is sending federation. + if hs.config.worker.federation_shard_config.instances == [ + self._instance_name + ]: + self._queue_presence_updates = False + + # The queue of recently queued updates as tuples of: `(timestamp, + # stream_id, destinations, user_ids)`. We don't store the full states + # for efficiency, and remote workers will already have the full states + # cached. + self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]] + + self._next_id = 1 + + # Map from instance name to current token + self._current_tokens = {} # type: Dict[str, int] + + if self._queue_presence_updates: + self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS) + + def _clear_queue(self): + """Clear out older entries from the queue.""" + clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS + + # The queue is sorted by timestamp, so we can bisect to find the right + # place to purge before. Note that we are searching using a 1-tuple with + # the time, which does The Right Thing since the queue is a tuple where + # the first item is a timestamp. + index = bisect(self._queue, (clear_before,)) + self._queue = self._queue[index:] + + def send_presence_to_destinations( + self, states: Collection[UserPresenceState], destinations: Collection[str] + ) -> None: + """Send the presence states to the given destinations. + + Will forward to the local federation sender (if there is one) and queue + to send over replication (if there are other federation sender instances.). + + Must only be called on the presence writer process. + """ + + # This should only be called on a presence writer. + assert self._presence_writer + + if self._federation: + self._federation.send_presence_to_destinations( + states=states, + destinations=destinations, + ) + + if not self._queue_presence_updates: + return + + now = self._clock.time_msec() + + stream_id = self._next_id + self._next_id += 1 + + self._queue.append((now, stream_id, destinations, {s.user_id for s in states})) + + self._notifier.notify_replication() + + def get_current_token(self, instance_name: str) -> int: + """Get the current position of the stream. + + On workers this returns the last stream ID received from replication. + """ + if instance_name == self._instance_name: + return self._next_id - 1 + else: + return self._current_tokens.get(instance_name, 0) + + async def get_replication_rows( + self, + instance_name: str, + from_token: int, + upto_token: int, + target_row_count: int, + ) -> Tuple[List[Tuple[int, Tuple[str, str]]], int, bool]: + """Get all the updates between the two tokens. + + We return rows in the form of `(destination, user_id)` to keep the size + of each row bounded (rather than returning the sets in a row). + + On workers this will query the presence writer process via HTTP replication. + """ + if instance_name != self._instance_name: + # If not local we query over http replication from the presence + # writer + result = await self._repl_client( + instance_name=instance_name, + stream_name=PresenceFederationStream.NAME, + from_token=from_token, + upto_token=upto_token, + ) + return result["updates"], result["upto_token"], result["limited"] + + # If the from_token is the current token then there's nothing to return + # and we can trivially no-op. + if from_token == self._next_id - 1: + return [], upto_token, False + + # We can find the correct position in the queue by noting that there is + # exactly one entry per stream ID, and that the last entry has an ID of + # `self._next_id - 1`, so we can count backwards from the end. + # + # Since we are returning all states in the range `from_token < stream_id + # <= upto_token` we look for the index with a `stream_id` of `from_token + # + 1`. + # + # Since the start of the queue is periodically truncated we need to + # handle the case where `from_token` stream ID has already been dropped. + start_idx = max(from_token + 1 - self._next_id, -len(self._queue)) + + to_send = [] # type: List[Tuple[int, Tuple[str, str]]] + limited = False + new_id = upto_token + for _, stream_id, destinations, user_ids in self._queue[start_idx:]: + if stream_id <= from_token: + # Paranoia check that we are actually only sending states that + # are have stream_id strictly greater than from_token. We should + # never hit this. + logger.warning( + "Tried returning presence federation stream ID: %d less than from_token: %d (next_id: %d, len: %d)", + stream_id, + from_token, + self._next_id, + len(self._queue), + ) + continue + + if stream_id > upto_token: + break + + new_id = stream_id + + to_send.extend( + (stream_id, (destination, user_id)) + for destination in destinations + for user_id in user_ids + ) + + if len(to_send) > target_row_count: + limited = True + break + + return to_send, new_id, limited + + async def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: list + ): + if stream_name != PresenceFederationStream.NAME: + return + + # We keep track of the current tokens (so that we can catch up with anything we missed after a disconnect) + self._current_tokens[instance_name] = token + + # If we're a federation sender we pull out the presence states to send + # and forward them on. + if not self._federation: + return + + hosts_to_users = {} # type: Dict[str, Set[str]] + for row in rows: + hosts_to_users.setdefault(row.destination, set()).add(row.user_id) + + for host, user_ids in hosts_to_users.items(): + states = await self._presence_handler.current_state_for_users(user_ids) + self._federation.send_presence_to_destinations( + states=states.values(), + destinations=[host], + ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 2bbfac6471..20700fc5a8 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -19,7 +19,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -28,7 +28,6 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.api.room_versions import RoomVersion from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID @@ -64,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() self.account_data_handler = hs.get_account_data_handler() + self.event_auth_handler = hs.get_event_auth_handler() self.member_linearizer = Linearizer(name="member") @@ -178,62 +178,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) - async def _can_join_without_invite( - self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str - ) -> bool: - """ - Check whether a user can join a room without an invite. - - When joining a room with restricted joined rules (as defined in MSC3083), - the membership of spaces must be checked during join. - - Args: - state_ids: The state of the room as it currently is. - room_version: The room version of the room being joined. - user_id: The user joining the room. - - Returns: - True if the user can join the room, false otherwise. - """ - # This only applies to room versions which support the new join rule. - if not room_version.msc3083_join_rules: - return True - - # If there's no join rule, then it defaults to public (so this doesn't apply). - join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) - if not join_rules_event_id: - return True - - # If the join rule is not restricted, this doesn't apply. - join_rules_event = await self.store.get_event(join_rules_event_id) - if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED: - return True - - # If allowed is of the wrong form, then only allow invited users. - allowed_spaces = join_rules_event.content.get("allow", []) - if not isinstance(allowed_spaces, list): - return False - - # Get the list of joined rooms and see if there's an overlap. - joined_rooms = await self.store.get_rooms_for_user(user_id) - - # Pull out the other room IDs, invalid data gets filtered. - for space in allowed_spaces: - if not isinstance(space, dict): - continue - - space_id = space.get("space") - if not isinstance(space_id, str): - continue - - # The user was joined to one of the spaces specified, they can join - # this room! - if space_id in joined_rooms: - return True - - # The user was not in any of the required spaces. - return False - async def _local_membership_update( self, requester: Requester, @@ -302,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if ( newly_joined and not user_is_invited - and not await self._can_join_without_invite( + and not await self.event_auth_handler.can_join_without_invite( prev_state_ids, event.room_version, user_id ) ): @@ -1100,7 +1044,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): class RoomMemberMasterHandler(RoomMemberHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.distributor = hs.get_distributor() diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml.py index 80ba65b9e0..80ba65b9e0 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml.py diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 8d00ffdc73..044ff06d84 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -18,6 +18,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -40,7 +41,7 @@ from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest -from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters +from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer from synapse.util.stringutils import random_string diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dc8ee8cd17..a9a3ee05c3 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -14,7 +14,17 @@ # limitations under the License. import itertools import logging -from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + FrozenSet, + List, + Optional, + Set, + Tuple, +) import attr from prometheus_client import Counter @@ -28,7 +38,6 @@ from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( - Collection, JsonDict, MutableStateMap, Requester, diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 0eeb7c03f2..5414ce77d8 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Any +from typing import TYPE_CHECKING, Any from twisted.web.client import PartialDownloadError @@ -22,13 +22,16 @@ from synapse.api.errors import Codes, LoginError, SynapseError from synapse.config.emailconfig import ThreepidBehaviour from synapse.util import json_decoder +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class UserInteractiveAuthChecker: """Abstract base class for an interactive auth checker""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): pass def is_enabled(self) -> bool: @@ -57,10 +60,10 @@ class UserInteractiveAuthChecker: class DummyAuthChecker(UserInteractiveAuthChecker): AUTH_TYPE = LoginType.DUMMY - def is_enabled(self): + def is_enabled(self) -> bool: return True - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: return True @@ -70,24 +73,24 @@ class TermsAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: return True class RecaptchaAuthChecker(UserInteractiveAuthChecker): AUTH_TYPE = LoginType.RECAPTCHA - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._enabled = bool(hs.config.recaptcha_private_key) self._http_client = hs.get_proxied_http_client() self._url = hs.config.recaptcha_siteverify_api self._secret = hs.config.recaptcha_private_key - def is_enabled(self): + def is_enabled(self) -> bool: return self._enabled - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: try: user_response = authdict["response"] except KeyError: @@ -132,11 +135,11 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): class _BaseThreepidAuthChecker: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() - async def _check_threepid(self, medium, authdict): + async def _check_threepid(self, medium: str, authdict: dict) -> dict: if "threepid_creds" not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) @@ -206,31 +209,31 @@ class _BaseThreepidAuthChecker: class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): AUTH_TYPE = LoginType.EMAIL_IDENTITY - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): UserInteractiveAuthChecker.__init__(self, hs) _BaseThreepidAuthChecker.__init__(self, hs) - def is_enabled(self): + def is_enabled(self) -> bool: return self.hs.config.threepid_behaviour_email in ( ThreepidBehaviour.REMOTE, ThreepidBehaviour.LOCAL, ) - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("email", authdict) class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): AUTH_TYPE = LoginType.MSISDN - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): UserInteractiveAuthChecker.__init__(self, hs) _BaseThreepidAuthChecker.__init__(self, hs) - def is_enabled(self): + def is_enabled(self) -> bool: return bool(self.hs.config.account_threepid_delegate_msisdn) - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("msisdn", authdict) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 9b1e6d5c18..dacc4f3076 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -44,7 +44,6 @@ class UserDirectoryHandler(StateDeltasHandler): super().__init__(hs) self.store = hs.get_datastore() - self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -302,10 +301,12 @@ class UserDirectoryHandler(StateDeltasHandler): # ignore the change return - users_with_profile = await self.state.get_current_users_in_room(room_id) + other_users_in_room_with_profiles = ( + await self.store.get_users_in_room_with_profiles(room_id) + ) # Remove every user from the sharing tables for that room. - for user_id in users_with_profile.keys(): + for user_id in other_users_in_room_with_profiles.keys(): await self.store.remove_user_who_share_room(user_id, room_id) # Then, re-add them to the tables. @@ -314,7 +315,7 @@ class UserDirectoryHandler(StateDeltasHandler): # which when ran over an entire room, will result in the same values # being added multiple times. The batching upserts shouldn't make this # too bad, though. - for user_id, profile in users_with_profile.items(): + for user_id, profile in other_users_in_room_with_profiles.items(): await self._handle_new_user(room_id, user_id, profile) async def _handle_new_user( @@ -336,7 +337,7 @@ class UserDirectoryHandler(StateDeltasHandler): room_id ) # Now we update users who share rooms with users. - users_with_profile = await self.state.get_current_users_in_room(room_id) + other_users_in_room = await self.store.get_users_in_room(room_id) if is_public: await self.store.add_users_in_public_rooms(room_id, (user_id,)) @@ -352,14 +353,14 @@ class UserDirectoryHandler(StateDeltasHandler): # We don't care about appservice users. if not is_appservice: - for other_user_id in users_with_profile: + for other_user_id in other_users_in_room: if user_id == other_user_id: continue to_insert.add((user_id, other_user_id)) # Next we need to update for every local user in the room - for other_user_id in users_with_profile: + for other_user_id in other_users_in_room: if user_id == other_user_id: continue diff --git a/synapse/http/client.py b/synapse/http/client.py index 1730187ffa..5f40f16e24 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -33,6 +33,7 @@ import treq from canonicaljson import encode_canonical_json from netaddr import AddrFormatError, IPAddress, IPSet from prometheus_client import Counter +from typing_extensions import Protocol from zope.interface import implementer, provider from OpenSSL import SSL @@ -754,6 +755,16 @@ def _timeout_to_request_timed_out_error(f: Failure): return f +class ByteWriteable(Protocol): + """The type of object which must be passed into read_body_with_max_size. + + Typically this is a file object. + """ + + def write(self, data: bytes) -> int: + pass + + class BodyExceededMaxSize(Exception): """The maximum allowed size of the HTTP body was exceeded.""" @@ -790,7 +801,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): transport = None # type: Optional[ITCPTransport] def __init__( - self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] + self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int] ): self.stream = stream self.deferred = deferred @@ -830,7 +841,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): def read_body_with_max_size( - response: IResponse, stream: BinaryIO, max_size: Optional[int] + response: IResponse, stream: ByteWriteable, max_size: Optional[int] ) -> defer.Deferred: """ Read a HTTP response body to a file-object. Optionally enforcing a maximum file size. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index d48721a4e2..bb837b7b19 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -1,5 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import cgi +import codecs import logging import random import sys +import typing import urllib.parse -from io import BytesIO +from io import BytesIO, StringIO from typing import Callable, Dict, List, Optional, Tuple, Union import attr @@ -72,6 +73,9 @@ incoming_responses_counter = Counter( "synapse_http_matrixfederationclient_responses", "", ["method", "code"] ) +# a federation response can be rather large (eg a big state_ids is 50M or so), so we +# need a generous limit here. +MAX_RESPONSE_SIZE = 100 * 1024 * 1024 MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 @@ -167,12 +171,27 @@ async def _handle_json_response( try: check_content_type_is_json(response.headers) - # Use the custom JSON decoder (partially re-implements treq.json_content). - d = treq.text_content(response, encoding="utf-8") - d.addCallback(json_decoder.decode) + buf = StringIO() + d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) + def parse(_len: int): + return json_decoder.decode(buf.getvalue()) + + d.addCallback(parse) + body = await make_deferred_yieldable(d) + except BodyExceededMaxSize as e: + # The response was too big. + logger.warning( + "{%s} [%s] JSON response exceeded max size %i - %s %s", + request.txn_id, + request.destination, + MAX_RESPONSE_SIZE, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=False) from e except ValueError as e: # The JSON content was invalid. logger.warning( @@ -218,6 +237,18 @@ async def _handle_json_response( return body +class BinaryIOWrapper: + """A wrapper for a TextIO which converts from bytes on the fly.""" + + def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"): + self.decoder = codecs.getincrementaldecoder(encoding)(errors) + self.file = file + + def write(self, b: Union[bytes, bytearray]) -> int: + self.file.write(self.decoder.decode(b)) + return len(b) + + class MatrixFederationHttpClient: """HTTP client used to talk to other homeservers over the federation protocol. Send client certificates and signs requests. diff --git a/synapse/http/site.py b/synapse/http/site.py index 32b5e19c09..671fd3fbcc 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,13 +14,14 @@ import contextlib import logging import time -from typing import Optional, Tuple, Type, Union +from typing import Optional, Tuple, Union import attr from zope.interface import implementer -from twisted.internet.interfaces import IAddress +from twisted.internet.interfaces import IAddress, IReactorTime from twisted.python.failure import Failure +from twisted.web.resource import IResource from twisted.web.server import Request, Site from synapse.config.server import ListenerConfig @@ -49,6 +50,7 @@ class SynapseRequest(Request): * Redaction of access_token query-params in __repr__ * Logging at start and end * Metrics to record CPU, wallclock and DB time by endpoint. + * A limit to the size of request which will be accepted It also provides a method `processing`, which returns a context manager. If this method is called, the request won't be logged until the context manager is closed; @@ -59,8 +61,9 @@ class SynapseRequest(Request): logcontext: the log context for this request """ - def __init__(self, channel, *args, **kw): + def __init__(self, channel, *args, max_request_body_size=1024, **kw): Request.__init__(self, channel, *args, **kw) + self._max_request_body_size = max_request_body_size self.site = channel.site # type: SynapseSite self._channel = channel # this is used by the tests self.start_time = 0.0 @@ -97,6 +100,18 @@ class SynapseRequest(Request): self.site.site_tag, ) + def handleContentChunk(self, data): + # we should have a `content` by now. + assert self.content, "handleContentChunk() called before gotLength()" + if self.content.tell() + len(data) > self._max_request_body_size: + logger.warning( + "Aborting connection from %s because the request exceeds maximum size", + self.client, + ) + self.transport.abortConnection() + return + super().handleContentChunk(data) + @property def requester(self) -> Optional[Union[Requester, str]]: return self._requester @@ -485,29 +500,55 @@ class _XForwardedForAddress: class SynapseSite(Site): """ - Subclass of a twisted http Site that does access logging with python's - standard logging + Synapse-specific twisted http Site + + This does two main things. + + First, it replaces the requestFactory in use so that we build SynapseRequests + instead of regular t.w.server.Requests. All of the constructor params are really + just parameters for SynapseRequest. + + Second, it inhibits the log() method called by Request.finish, since SynapseRequest + does its own logging. """ def __init__( self, - logger_name, - site_tag, + logger_name: str, + site_tag: str, config: ListenerConfig, - resource, + resource: IResource, server_version_string, - *args, - **kwargs, + max_request_body_size: int, + reactor: IReactorTime, ): - Site.__init__(self, resource, *args, **kwargs) + """ + + Args: + logger_name: The name of the logger to use for access logs. + site_tag: A tag to use for this site - mostly in access logs. + config: Configuration for the HTTP listener corresponding to this site + resource: The base of the resource tree to be used for serving requests on + this site + server_version_string: A string to present for the Server header + max_request_body_size: Maximum request body length to allow before + dropping the connection + reactor: reactor to be used to manage connection timeouts + """ + Site.__init__(self, resource, reactor=reactor) self.site_tag = site_tag assert config.http_options is not None proxied = config.http_options.x_forwarded - self.requestFactory = ( - XForwardedForRequest if proxied else SynapseRequest - ) # type: Type[Request] + request_class = XForwardedForRequest if proxied else SynapseRequest + + def request_factory(channel, queued) -> Request: + return request_class( + channel, max_request_body_size=max_request_body_size, queued=queued + ) + + self.requestFactory = request_factory # type: ignore self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 4e8b0f8d10..c515690b38 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -226,11 +226,11 @@ class RemoteHandler(logging.Handler): old_buffer = self._buffer self._buffer = deque() - for i in range(buffer_split): + for _ in range(buffer_split): self._buffer.append(old_buffer.popleft()) end_buffer = [] - for i in range(buffer_split): + for _ in range(buffer_split): end_buffer.append(old_buffer.pop()) self._buffer.extend(reversed(end_buffer)) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index e78343f554..7fc11a9ac2 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -258,7 +258,8 @@ class LoggingContext: child to the parent Args: - name (str): Name for the context for debugging. + name: Name for the context for logging. If this is omitted, it is + inherited from the parent context. parent_context (LoggingContext|None): The parent of the new context """ @@ -282,7 +283,6 @@ class LoggingContext: request: Optional[ContextRequest] = None, ) -> None: self.previous_context = current_context() - self.name = name # track the resources used by this context so far self._resource_usage = ContextResourceUsage() @@ -314,10 +314,17 @@ class LoggingContext: # the request param overrides the request from the parent context self.request = request + # if we don't have a `name`, but do have a parent context, use its name. + if self.parent_context and name is None: + name = str(self.parent_context) + if name is None: + raise ValueError( + "LoggingContext must be given either a name or a parent context" + ) + self.name = name + def __str__(self) -> str: - if self.request: - return self.request.request_id - return "%s@%x" % (self.name, id(self)) + return self.name @classmethod def current_context(cls) -> LoggingContextOrSentinel: @@ -694,17 +701,13 @@ def nested_logging_context(suffix: str) -> LoggingContext: "Starting nested logging context from sentinel context: metrics will be lost" ) parent_context = None - prefix = "" - request = None else: assert isinstance(curr_context, LoggingContext) parent_context = curr_context - prefix = str(parent_context.name) - request = parent_context.request + prefix = str(curr_context) return LoggingContext( prefix + "-" + suffix, parent_context=parent_context, - request=request, ) @@ -895,7 +898,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): parent_context = curr_context def g(): - with LoggingContext(parent_context=parent_context): + with LoggingContext(str(curr_context), parent_context=parent_context): return f(*args, **kwargs) return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index cbd0894e57..714caf84c3 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -241,19 +241,24 @@ class BackgroundProcessLoggingContext(LoggingContext): processes. """ - __slots__ = ["_id", "_proc"] - - def __init__(self, name: str, id: Optional[Union[int, str]] = None): - super().__init__(name) - self._id = id - + __slots__ = ["_proc"] + + def __init__(self, name: str, instance_id: Optional[Union[int, str]] = None): + """ + + Args: + name: The name of the background process. Each distinct `name` gets a + separate prometheus time series. + + instance_id: an identifer to add to `name` to distinguish this instance of + the named background process in the logs. If this is `None`, one is + made up based on id(self). + """ + if instance_id is None: + instance_id = id(self) + super().__init__("%s-%s" % (name, instance_id)) self._proc = _BackgroundProcess(name, self) - def __str__(self) -> str: - if self._id is not None: - return "%s-%s" % (self.name, self._id) - return "%s@%x" % (self.name, id(self)) - def start(self, rusage: "Optional[resource._RUsage]"): """Log context has started running (again).""" diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b114885809..5c7cc8d718 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -50,6 +50,7 @@ class ModuleApi: self._auth_handler = auth_handler self._server_name = hs.hostname self._presence_stream = hs.get_event_sources().sources["presence"] + self._state = hs.get_state_handler() # We expose these as properties below in order to attach a helpful docstring. self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient @@ -399,11 +400,9 @@ class ModuleApi: Note that this method can only be run on the main or federation_sender worker processes. """ - if not self._hs.should_send_federation(): - raise Exception( - "send_local_online_presence_to can only be run " - "on processes that send federation", - ) + # We pull out the presence handler here to break a cyclic + # dependency between the presence router and module API. + presence_handler = self._hs.get_presence_handler() local_users = set() remote_users = set() @@ -413,6 +412,12 @@ class ModuleApi: else: remote_users.add(user) + if remote_users and not self._hs.should_send_federation(): + raise Exception( + "send_local_online_presence_to can only be called with remote users " + "on processes that send federation", + ) + if local_users: # Force a presence initial_sync for these users next time they sync. await self._store.add_users_to_send_full_presence_to(local_users) @@ -424,11 +429,9 @@ class ModuleApi: UserID.from_string(user), from_key=None, include_offline=False ) - # Send to remote destinations - await make_deferred_yieldable( - # We pull the federation sender here as we can only do so on workers - # that support sending presence - self._hs.get_federation_sender().send_presence(presence_events) + # Send to remote destinations. + await presence_handler.maybe_send_presence_to_interested_destinations( + presence_events ) diff --git a/synapse/notifier.py b/synapse/notifier.py index d5ab77058d..b9531007e2 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -17,6 +17,7 @@ from collections import namedtuple from typing import ( Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -42,13 +43,7 @@ from synapse.logging.opentracing import log_kv, start_active_span from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.streams.config import PaginationConfig -from synapse.types import ( - Collection, - PersistedEventPosition, - RoomStreamToken, - StreamToken, - UserID, -) +from synapse.types import PersistedEventPosition, RoomStreamToken, StreamToken, UserID from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 50b470c310..350646f458 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -106,6 +106,10 @@ class BulkPushRuleEvaluator: self.store = hs.get_datastore() self.auth = hs.get_auth() + # Used by `RulesForRoom` to ensure only one thing mutates the cache at a + # time. Keyed off room_id. + self._rules_linearizer = Linearizer(name="rules_for_room") + self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", @@ -123,7 +127,16 @@ class BulkPushRuleEvaluator: dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = self._get_rules_for_room(room_id) + + rules_for_room_data = self._get_rules_for_room(room_id) + rules_for_room = RulesForRoom( + hs=self.hs, + room_id=room_id, + rules_for_room_cache=self._get_rules_for_room.cache, + room_push_rule_cache_metrics=self.room_push_rule_cache_metrics, + linearizer=self._rules_linearizer, + cached_data=rules_for_room_data, + ) rules_by_user = await rules_for_room.get_rules(event, context) @@ -142,17 +155,12 @@ class BulkPushRuleEvaluator: return rules_by_user @lru_cache() - def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": - """Get the current RulesForRoom object for the given room id""" - # It's important that RulesForRoom gets added to self._get_rules_for_room.cache + def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData": + """Get the current RulesForRoomData object for the given room id""" + # It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) - return RulesForRoom( - self.hs, - room_id, - self._get_rules_for_room.cache, - self.room_push_rule_cache_metrics, - ) + return RulesForRoomData() async def _get_power_levels_and_sender_level( self, event: EventBase, context: EventContext @@ -282,11 +290,49 @@ def _condition_checker( return True +@attr.s(slots=True) +class RulesForRoomData: + """The data stored in the cache by `RulesForRoom`. + + We don't store `RulesForRoom` directly in the cache as we want our caches to + *only* include data, and not references to e.g. the data stores. + """ + + # event_id -> (user_id, state) + member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict) + # user_id -> rules + rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict) + + # The last state group we updated the caches for. If the state_group of + # a new event comes along, we know that we can just return the cached + # result. + # On invalidation of the rules themselves (if the user changes them), + # we invalidate everything and set state_group to `object()` + state_group = attr.ib(type=Union[object, int], factory=object) + + # A sequence number to keep track of when we're allowed to update the + # cache. We bump the sequence number when we invalidate the cache. If + # the sequence number changes while we're calculating stuff we should + # not update the cache with it. + sequence = attr.ib(type=int, default=0) + + # A cache of user_ids that we *know* aren't interesting, e.g. user_ids + # owned by AS's, or remote users, etc. (I.e. users we will never need to + # calculate push for) + # These never need to be invalidated as we will never set up push for + # them. + uninteresting_user_set = attr.ib(type=Set[str], factory=set) + + class RulesForRoom: """Caches push rules for users in a room. This efficiently handles users joining/leaving the room by not invalidating the entire cache for the room. + + A new instance is constructed for each call to + `BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from + previous calls passed in. """ def __init__( @@ -295,6 +341,8 @@ class RulesForRoom: room_id: str, rules_for_room_cache: LruCache, room_push_rule_cache_metrics: CacheMetric, + linearizer: Linearizer, + cached_data: RulesForRoomData, ): """ Args: @@ -303,38 +351,21 @@ class RulesForRoom: rules_for_room_cache: The cache object that caches these RoomsForUser objects. room_push_rule_cache_metrics: The metrics object + linearizer: The linearizer used to ensure only one thing mutates + the cache at a time. Keyed off room_id + cached_data: Cached data from previous calls to `self.get_rules`, + can be mutated. """ self.room_id = room_id self.is_mine_id = hs.is_mine_id self.store = hs.get_datastore() self.room_push_rule_cache_metrics = room_push_rule_cache_metrics - self.linearizer = Linearizer(name="rules_for_room") - - # event_id -> (user_id, state) - self.member_map = {} # type: Dict[str, Tuple[str, str]] - # user_id -> rules - self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]] - - # The last state group we updated the caches for. If the state_group of - # a new event comes along, we know that we can just return the cached - # result. - # On invalidation of the rules themselves (if the user changes them), - # we invalidate everything and set state_group to `object()` - self.state_group = object() - - # A sequence number to keep track of when we're allowed to update the - # cache. We bump the sequence number when we invalidate the cache. If - # the sequence number changes while we're calculating stuff we should - # not update the cache with it. - self.sequence = 0 - - # A cache of user_ids that we *know* aren't interesting, e.g. user_ids - # owned by AS's, or remote users, etc. (I.e. users we will never need to - # calculate push for) - # These never need to be invalidated as we will never set up push for - # them. - self.uninteresting_user_set = set() # type: Set[str] + # Used to ensure only one thing mutates the cache at a time. Keyed off + # room_id. + self.linearizer = linearizer + + self.data = cached_data # We need to be clever on the invalidating caches callbacks, as # otherwise the invalidation callback holds a reference to the object, @@ -352,25 +383,25 @@ class RulesForRoom: """ state_group = context.state_group - if state_group and self.state_group == state_group: + if state_group and self.data.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() - return self.rules_by_user + return self.data.rules_by_user - with (await self.linearizer.queue(())): - if state_group and self.state_group == state_group: + with (await self.linearizer.queue(self.room_id)): + if state_group and self.data.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() - return self.rules_by_user + return self.data.rules_by_user self.room_push_rule_cache_metrics.inc_misses() ret_rules_by_user = {} missing_member_event_ids = {} - if state_group and self.state_group == context.prev_group: + if state_group and self.data.state_group == context.prev_group: # If we have a simple delta then we can reuse most of the previous # results. - ret_rules_by_user = self.rules_by_user + ret_rules_by_user = self.data.rules_by_user current_state_ids = context.delta_ids push_rules_delta_state_cache_metric.inc_hits() @@ -393,24 +424,24 @@ class RulesForRoom: if typ != EventTypes.Member: continue - if user_id in self.uninteresting_user_set: + if user_id in self.data.uninteresting_user_set: continue if not self.is_mine_id(user_id): - self.uninteresting_user_set.add(user_id) + self.data.uninteresting_user_set.add(user_id) continue if self.store.get_if_app_services_interested_in_user(user_id): - self.uninteresting_user_set.add(user_id) + self.data.uninteresting_user_set.add(user_id) continue event_id = current_state_ids[key] - res = self.member_map.get(event_id, None) + res = self.data.member_map.get(event_id, None) if res: user_id, state = res if state == Membership.JOIN: - rules = self.rules_by_user.get(user_id, None) + rules = self.data.rules_by_user.get(user_id, None) if rules: ret_rules_by_user[user_id] = rules continue @@ -430,7 +461,7 @@ class RulesForRoom: else: # The push rules didn't change but lets update the cache anyway self.update_cache( - self.sequence, + self.data.sequence, members={}, # There were no membership changes rules_by_user=ret_rules_by_user, state_group=state_group, @@ -461,7 +492,7 @@ class RulesForRoom: for. Used when updating the cache. event: The event we are currently computing push rules for. """ - sequence = self.sequence + sequence = self.data.sequence rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) @@ -501,23 +532,11 @@ class RulesForRoom: self.update_cache(sequence, members, ret_rules_by_user, state_group) - def invalidate_all(self) -> None: - # Note: Don't hand this function directly to an invalidation callback - # as it keeps a reference to self and will stop this instance from being - # GC'd if it gets dropped from the rules_to_user cache. Instead use - # `self.invalidate_all_cb` - logger.debug("Invalidating RulesForRoom for %r", self.room_id) - self.sequence += 1 - self.state_group = object() - self.member_map = {} - self.rules_by_user = {} - push_rules_invalidation_counter.inc() - def update_cache(self, sequence, members, rules_by_user, state_group) -> None: - if sequence == self.sequence: - self.member_map.update(members) - self.rules_by_user = rules_by_user - self.state_group = state_group + if sequence == self.data.sequence: + self.data.member_map.update(members) + self.data.rules_by_user = rules_by_user + self.data.state_group = state_group @attr.attrs(slots=True, frozen=True) @@ -535,6 +554,10 @@ class _Invalidation: room_id = attr.ib(type=str) def __call__(self) -> None: - rules = self.cache.get(self.room_id, None, update_metrics=False) - if rules: - rules.invalidate_all() + rules_data = self.cache.get(self.room_id, None, update_metrics=False) + if rules_data: + rules_data.sequence += 1 + rules_data.state_group = object() + rules_data.member_map = {} + rules_data.rules_by_user = {} + push_rules_invalidation_counter.inc() diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index cd89b54305..99a18874d1 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -19,8 +19,9 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.interfaces import IDelayedCall from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import Pusher, PusherConfig, ThrottleParams +from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams from synapse.push.mailer import Mailer +from synapse.util.threepids import validate_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -71,6 +72,12 @@ class EmailPusher(Pusher): self._is_processing = False + # Make sure that the email is valid. + try: + validate_email(self.email) + except ValueError: + raise PusherConfigException("Invalid email") + def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 564a5ed0df..579fcdf472 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -62,7 +62,9 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self._account_validity = hs.config.account_validity + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config @@ -236,7 +238,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) @@ -266,7 +268,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 2a1c925ee8..2de946f464 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -85,7 +85,7 @@ REQUIREMENTS = [ "typing-extensions>=3.7.4", # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. - "cryptography>=3.4.7;python_version>='3.6'", + "cryptography>=3.4.7", ] CONDITIONAL_REQUIREMENTS = { @@ -100,14 +100,9 @@ CONDITIONAL_REQUIREMENTS = { # that use the protocol, such as Let's Encrypt. "acme": [ "txacme>=0.9.2", - # txacme depends on eliot. Eliot 1.8.0 is incompatible with - # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418 - "eliot<1.8.0;python_version<'3.5.3'", ], "saml2": [ - # pysaml2 6.4.0 is incompatible with Python 3.5 (see https://github.com/IdentityPython/pysaml2/issues/749) - "pysaml2>=4.5.0,<6.4.0;python_version<'3.6'", - "pysaml2>=4.5.0;python_version>='3.6'", + "pysaml2>=4.5.0", ], "oidc": ["authlib>=0.14.0"], # systemd-python is necessary for logging to the systemd journal via diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index ece03467b5..5685cf2121 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -158,7 +158,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): def make_client(cls, hs): """Create a client that makes requests. - Returns a callable that accepts the same parameters as `_serialize_payload`. + Returns a callable that accepts the same parameters as + `_serialize_payload`, and also accepts an optional `instance_name` + parameter to specify which instance to hit (the instance must be in + the `instance_map` config). """ clock = hs.get_clock() client = hs.get_simple_http_client() diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py deleted file mode 100644 index 57327d910d..0000000000 --- a/synapse/replication/slave/storage/presence.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.replication.tcp.streams import PresenceStream -from synapse.storage import DataStore -from synapse.storage.database import DatabasePool -from synapse.storage.databases.main.presence import PresenceStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker - - -class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") - - self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore - - self.presence_stream_cache = StreamChangeCache( - "PresenceStreamChangeCache", self._presence_id_gen.get_current_token() - ) - - _get_active_presence = DataStore._get_active_presence - take_presence_startup_info = DataStore.take_presence_startup_info - _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] - get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] - - def get_current_presence_token(self): - return self._presence_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(instance_name, token) - for row in rows: - self.presence_stream_cache.entity_has_changed(row.user_id, token) - self._get_presence_for_user.invalidate((row.user_id,)) - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index ce5d651cb8..4f3c6a18b6 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -29,7 +29,6 @@ from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, GroupServerStream, - PresenceStream, PushersStream, PushRulesStream, ReceiptsStream, @@ -191,8 +190,6 @@ class ReplicationDataHandler: self.stop_pusher(row.user_id, row.app_id, row.pushkey) else: await self.start_pusher(row.user_id, row.app_id, row.pushkey) - elif stream_name == PresenceStream.NAME: - await self._presence_handler.process_replication_rows(token, rows) elif stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -221,6 +218,10 @@ class ReplicationDataHandler: membership=row.data.membership, ) + await self._presence_handler.process_replication_rows( + stream_name, instance_name, token, rows + ) + # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is # greater than the received row position. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 2ce1b9f222..7ced4c543c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -55,6 +55,8 @@ from synapse.replication.tcp.streams import ( CachesStream, EventsStream, FederationStream, + PresenceFederationStream, + PresenceStream, ReceiptsStream, Stream, TagAccountDataStream, @@ -99,6 +101,10 @@ class ReplicationCommandHandler: self._instance_id = hs.get_instance_id() self._instance_name = hs.get_instance_name() + self._is_presence_writer = ( + hs.get_instance_name() in hs.config.worker.writers.presence + ) + self._streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] @@ -153,6 +159,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, (PresenceStream, PresenceFederationStream)): + # Only add PresenceStream as a source on the instance in charge + # of presence. + if self._is_presence_writer: + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue @@ -350,7 +364,7 @@ class ReplicationCommandHandler: ) -> Optional[Awaitable[None]]: user_sync_counter.inc() - if self._is_master: + if self._is_presence_writer: return self._presence_handler.update_external_syncs_row( cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) @@ -360,7 +374,7 @@ class ReplicationCommandHandler: def on_CLEAR_USER_SYNC( self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand ) -> Optional[Awaitable[None]]: - if self._is_master: + if self._is_presence_writer: return self._presence_handler.update_external_syncs_clear(cmd.instance_id) else: return None diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 6860576e78..6e3705364f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -49,7 +49,7 @@ import fcntl import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Collection, List, Optional from prometheus_client import Counter from zope.interface import Interface, implementer @@ -76,7 +76,6 @@ from synapse.replication.tcp.commands import ( ServerCommand, parse_command_from_line, ) -from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index fb74ac4e98..4c0023c68a 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -30,6 +30,7 @@ from synapse.replication.tcp.streams._base import ( CachesStream, DeviceListsStream, GroupServerStream, + PresenceFederationStream, PresenceStream, PublicRoomsStream, PushersStream, @@ -50,6 +51,7 @@ STREAMS_MAP = { EventsStream, BackfillStream, PresenceStream, + PresenceFederationStream, TypingStream, ReceiptsStream, PushRulesStream, @@ -71,6 +73,7 @@ __all__ = [ "Stream", "BackfillStream", "PresenceStream", + "PresenceFederationStream", "TypingStream", "ReceiptsStream", "PushRulesStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 520c45f151..b03824925a 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -272,15 +272,22 @@ class PresenceStream(Stream): NAME = "presence" ROW_TYPE = PresenceStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() - if hs.config.worker_app is None: - # on the master, query the presence handler + if hs.get_instance_name() in hs.config.worker.writers.presence: + # on the presence writer, query the presence handler presence_handler = hs.get_presence_handler() - update_function = presence_handler.get_all_presence_updates + + from synapse.handlers.presence import PresenceHandler + + assert isinstance(presence_handler, PresenceHandler) + + update_function = ( + presence_handler.get_all_presence_updates + ) # type: UpdateFunction else: - # Query master process + # Query presence writer process update_function = make_http_update_function(hs, self.NAME) super().__init__( @@ -290,6 +297,30 @@ class PresenceStream(Stream): ) +class PresenceFederationStream(Stream): + """A stream used to send ad hoc presence updates over federation. + + Streams the remote destination and the user ID of the presence state to + send. + """ + + @attr.s(slots=True, auto_attribs=True) + class PresenceFederationStreamRow: + destination: str + user_id: str + + NAME = "presence_federation" + ROW_TYPE = PresenceFederationStreamRow + + def __init__(self, hs: "HomeServer"): + federation_queue = hs.get_presence_handler().get_federation_queue() + super().__init__( + hs.get_instance_name(), + federation_queue.get_current_token, + federation_queue.get_replication_rows, + ) + + class TypingStream(Stream): TypingStreamRow = namedtuple( "TypingStreamRow", ("room_id", "user_ids") # str # list(str) diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html new file mode 100644 index 0000000000..b751359bdf --- /dev/null +++ b/synapse/res/templates/account_previously_renewed.html @@ -0,0 +1 @@ +<html><body>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html> diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html index 894da030af..e8c0f52f05 100644 --- a/synapse/res/templates/account_renewed.html +++ b/synapse/res/templates/account_renewed.html @@ -1 +1 @@ -<html><body>Your account has been successfully renewed.</body><html> +<html><body>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html> diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index d0cf121743..f289ffe3d0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -37,9 +37,11 @@ from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester from synapse.util import json_decoder if TYPE_CHECKING: + from synapse.api.auth import Auth + from synapse.handlers.pagination import PaginationHandler + from synapse.handlers.room import RoomShutdownHandler from synapse.server import HomeServer - logger = logging.getLogger(__name__) @@ -146,50 +148,14 @@ class DeleteRoomRestServlet(RestServlet): async def on_POST( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) - - content = parse_json_object_from_request(request) - - block = content.get("block", False) - if not isinstance(block, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'block' must be a boolean, if given", - Codes.BAD_JSON, - ) - - purge = content.get("purge", True) - if not isinstance(purge, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'purge' must be a boolean, if given", - Codes.BAD_JSON, - ) - - force_purge = content.get("force_purge", False) - if not isinstance(force_purge, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'force_purge' must be a boolean, if given", - Codes.BAD_JSON, - ) - - ret = await self.room_shutdown_handler.shutdown_room( - room_id=room_id, - new_room_user_id=content.get("new_room_user_id"), - new_room_name=content.get("room_name"), - message=content.get("message"), - requester_user_id=requester.user.to_string(), - block=block, + return await _delete_room( + request, + room_id, + self.auth, + self.room_shutdown_handler, + self.pagination_handler, ) - # Purge room - if purge: - await self.pagination_handler.purge_room(room_id, force=force_purge) - - return (200, ret) - class ListRoomRestServlet(RestServlet): """ @@ -282,7 +248,22 @@ class ListRoomRestServlet(RestServlet): class RoomRestServlet(RestServlet): - """Get room details. + """Manage a room. + + On GET : Get details of a room. + + On DELETE : Delete a room from server. + + It is a combination and improvement of shutdown and purge room. + + Shuts down a room by removing all local users from the room. + Blocking all future invites and joins to the room is optional. + + If desired any local aliases will be repointed to a new room + created by `new_room_user_id` and kicked users will be auto- + joined to the new room. + + If 'purge' is true, it will remove all traces of a room from the database. TODO: Add on_POST to allow room creation without joining the room """ @@ -293,6 +274,8 @@ class RoomRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.room_shutdown_handler = hs.get_room_shutdown_handler() + self.pagination_handler = hs.get_pagination_handler() async def on_GET( self, request: SynapseRequest, room_id: str @@ -308,6 +291,17 @@ class RoomRestServlet(RestServlet): return (200, ret) + async def on_DELETE( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + return await _delete_room( + request, + room_id, + self.auth, + self.room_shutdown_handler, + self.pagination_handler, + ) + class RoomMembersRestServlet(RestServlet): """ @@ -694,3 +688,55 @@ class RoomEventContextServlet(RestServlet): ) return 200, results + + +async def _delete_room( + request: SynapseRequest, + room_id: str, + auth: "Auth", + room_shutdown_handler: "RoomShutdownHandler", + pagination_handler: "PaginationHandler", +) -> Tuple[int, JsonDict]: + requester = await auth.get_user_by_req(request) + await assert_user_is_admin(auth, requester.user) + + content = parse_json_object_from_request(request) + + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) + + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + ret = await room_shutdown_handler.shutdown_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, + ) + + # Purge room + if purge: + await pagination_handler.purge_room(room_id, force=force_purge) + + return (200, ret) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index edda7861fa..8c9d21d3ea 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -14,6 +14,7 @@ import hashlib import hmac import logging +import secrets from http import HTTPStatus from typing import TYPE_CHECKING, Dict, List, Optional, Tuple @@ -375,7 +376,7 @@ class UserRegisterServlet(RestServlet): """ self._clear_old_nonces() - nonce = self.hs.get_secrets().token_hex(64) + nonce = secrets.token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) return 200, {"nonce": nonce} diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index c232484f29..2b24fe5aa6 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -35,10 +35,15 @@ class PresenceStatusRestServlet(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() + self._use_presence = hs.config.server.use_presence + async def on_GET(self, request, user_id): requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) + if not self._use_presence: + return 200, {"presence": "offline"} + if requester.user != user: allowed = await self.presence_handler.is_visible( observed_user=user, observer_user=requester.user @@ -80,7 +85,7 @@ class PresenceStatusRestServlet(RestServlet): except Exception: raise SynapseError(400, "Unable to parse state") - if self.hs.config.use_presence: + if self._use_presence: await self.presence_handler.set_state(user, state) return 200, {} diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 3aad15132d..085561d3e9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -39,7 +39,7 @@ from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import canonicalise_email, check_3pid_allowed +from synapse.util.threepids import check_3pid_allowed, validate_email from ._base import client_patterns, interactive_auth_handler @@ -92,7 +92,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Stored in the database "foo@bar.com" # User requests with "FOO@bar.com" would raise a Not Found error try: - email = canonicalise_email(body["email"]) + email = validate_email(body["email"]) except ValueError as e: raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] @@ -247,7 +247,7 @@ class PasswordRestServlet(RestServlet): # We store all email addresses canonicalised in the DB. # (See add_threepid in synapse/handlers/auth.py) try: - threepid["address"] = canonicalise_email(threepid["address"]) + threepid["address"] = validate_email(threepid["address"]) except ValueError as e: raise SynapseError(400, str(e)) # if using email, we must know about the email they're authing with! @@ -375,7 +375,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): # Otherwise the email will be sent to "FOO@bar.com" and stored as # "foo@bar.com" in database. try: - email = canonicalise_email(body["email"]) + email = validate_email(body["email"]) except ValueError as e: raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 0ad07fb895..2d1ad3d3fb 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -36,24 +36,40 @@ class AccountValidityRenewServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - self.success_html = hs.config.account_validity.account_renewed_html_content - self.failure_html = hs.config.account_validity.invalid_token_html_content + self.account_renewed_template = ( + hs.config.account_validity.account_validity_account_renewed_template + ) + self.account_previously_renewed_template = ( + hs.config.account_validity.account_validity_account_previously_renewed_template + ) + self.invalid_token_template = ( + hs.config.account_validity.account_validity_invalid_token_template + ) async def on_GET(self, request): if b"token" not in request.args: raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - token_valid = await self.account_activity_handler.renew_account( + ( + token_valid, + token_stale, + expiration_ts, + ) = await self.account_activity_handler.renew_account( renewal_token.decode("utf8") ) if token_valid: status_code = 200 - response = self.success_html + response = self.account_renewed_template.render(expiration_ts=expiration_ts) + elif token_stale: + status_code = 200 + response = self.account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) else: status_code = 404 - response = self.failure_html + response = self.invalid_token_template.render(expiration_ts=expiration_ts) respond_with_html(request, status_code, response) @@ -71,10 +87,12 @@ class AccountValiditySendMailServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - self.account_validity = self.hs.config.account_validity + self.account_validity_renew_by_email_enabled = ( + hs.config.account_validity.account_validity_renew_by_email_enabled + ) async def on_POST(self, request): - if not self.account_validity.renew_by_email_enabled: + if not self.account_validity_renew_by_email_enabled: raise AuthError( 403, "Account renewal via email is disabled on this server." ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b26aad7b34..a30a5df1b1 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -30,7 +30,7 @@ from synapse.api.errors import ( ) from synapse.config import ConfigError from synapse.config.captcha import CaptchaConfig -from synapse.config.consent_config import ConsentConfig +from synapse.config.consent import ConsentConfig from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.registration import RegistrationConfig @@ -49,7 +49,11 @@ from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import canonicalise_email, check_3pid_allowed +from synapse.util.threepids import ( + canonicalise_email, + check_3pid_allowed, + validate_email, +) from ._base import client_patterns, interactive_auth_handler @@ -111,7 +115,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # (See on_POST in EmailThreepidRequestTokenRestServlet # in synapse/rest/client/v2_alpha/account.py) try: - email = canonicalise_email(body["email"]) + email = validate_email(body["email"]) except ValueError as e: raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index c4550d3cf0..b19cd8afc5 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -32,14 +32,6 @@ TEMPLATE_LANGUAGE = "en" logger = logging.getLogger(__name__) -# use hmac.compare_digest if we have it (python 2.7.7), else just use equality -if hasattr(hmac, "compare_digest"): - compare_digest = hmac.compare_digest -else: - - def compare_digest(a, b): - return a == b - class ConsentResource(DirectServeHtmlResource): """A twisted Resource to display a privacy policy and gather consent to it @@ -209,5 +201,5 @@ class ConsentResource(DirectServeHtmlResource): .encode("ascii") ) - if not compare_digest(want_mac, userhmac): + if not hmac.compare_digest(want_mac, userhmac): raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index c57ac22e58..f648678b09 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -144,7 +144,7 @@ class RemoteKey(DirectServeJsonResource): # Note that the value is unused. cache_misses = {} # type: Dict[str, Dict[str, int]] - for (server_name, key_id, from_server), results in cached.items(): + for (server_name, key_id, _), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: @@ -206,7 +206,7 @@ class RemoteKey(DirectServeJsonResource): # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) else: - for ts_added, result in results: + for _, result in results: # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(result["key_json"])) diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 4088e7a059..09531ebf54 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -21,7 +21,7 @@ from typing import Callable, List NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") -def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]": +def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]: """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 80f017a4dd..024a105bf2 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -51,8 +51,6 @@ class UploadResource(DirectServeJsonResource): async def _async_render_POST(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request) - # TODO: The checks here are a bit late. The content will have - # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") if content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index e5634f9679..488b97b32e 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -61,6 +61,15 @@ class NewUserConsentResource(DirectServeHtmlResource): self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) return + # It should be impossible to get here without having first been through + # the pick-a-username step, which ensures chosen_localpart gets set. + if not session.chosen_localpart: + logger.warning("Session has no user name selected") + self._sso_handler.render_error( + request, "no_user", "No user name has been selected.", code=400 + ) + return + user_id = UserID(session.chosen_localpart, self._server_name) user_profile = { "display_name": session.display_name, diff --git a/synapse/secrets.py b/synapse/secrets.py deleted file mode 100644 index bf829251fd..0000000000 --- a/synapse/secrets.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Injectable secrets module for Synapse. - -See https://docs.python.org/3/library/secrets.html#module-secrets for the API -used in Python 3.6, and the API emulated in Python 2.7. -""" -import sys - -# secrets is available since python 3.6 -if sys.version_info[0:2] >= (3, 6): - import secrets - - class Secrets: - def token_bytes(self, nbytes: int = 32) -> bytes: - return secrets.token_bytes(nbytes) - - def token_hex(self, nbytes: int = 32) -> str: - return secrets.token_hex(nbytes) - - -else: - import binascii - import os - - class Secrets: - def token_bytes(self, nbytes: int = 32) -> bytes: - return os.urandom(nbytes) - - def token_hex(self, nbytes: int = 32) -> str: - return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii") diff --git a/synapse/server.py b/synapse/server.py index 42d2fad8e8..2337d2d9b4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -70,13 +70,14 @@ from synapse.handlers.acme import AcmeHandler from synapse.handlers.admin import AdminHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator -from synapse.handlers.cas_handler import CasHandler +from synapse.handlers.cas import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.directory import DirectoryHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler +from synapse.handlers.event_auth import EventAuthHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.federation import FederationHandler from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler @@ -125,7 +126,6 @@ from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, ) -from synapse.secrets import Secrets from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.worker_server_notices_sender import ( @@ -145,8 +145,8 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: from txredisapi import RedisProtocol - from synapse.handlers.oidc_handler import OidcHandler - from synapse.handlers.saml_handler import SamlHandler + from synapse.handlers.oidc import OidcHandler + from synapse.handlers.saml import SamlHandler T = TypeVar("T", bound=Callable[..., Any]) @@ -286,6 +286,14 @@ class HomeServer(metaclass=abc.ABCMeta): if self.config.run_background_tasks: self.setup_background_tasks() + def start_listening(self) -> None: + """Start the HTTP, manhole, metrics, etc listeners + + Does nothing in this base class; overridden in derived classes to start the + appropriate listeners. + """ + pass + def setup_background_tasks(self) -> None: """ Some handlers have side effects on instantiation (like registering @@ -417,10 +425,10 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_presence_handler(self) -> BasePresenceHandler: - if self.config.worker_app: - return WorkerPresenceHandler(self) - else: + if self.get_instance_name() in self.config.worker.writers.presence: return PresenceHandler(self) + else: + return WorkerPresenceHandler(self) @cache_in_self def get_typing_writer_handler(self) -> TypingWriterHandler: @@ -633,10 +641,6 @@ class HomeServer(metaclass=abc.ABCMeta): return GroupAttestionRenewer(self) @cache_in_self - def get_secrets(self) -> Secrets: - return Secrets() - - @cache_in_self def get_stats_handler(self) -> StatsHandler: return StatsHandler(self) @@ -696,13 +700,13 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_saml_handler(self) -> "SamlHandler": - from synapse.handlers.saml_handler import SamlHandler + from synapse.handlers.saml import SamlHandler return SamlHandler(self) @cache_in_self def get_oidc_handler(self) -> "OidcHandler": - from synapse.handlers.oidc_handler import OidcHandler + from synapse.handlers.oidc import OidcHandler return OidcHandler(self) @@ -747,6 +751,10 @@ class HomeServer(metaclass=abc.ABCMeta): return SpaceSummaryHandler(self) @cache_in_self + def get_event_auth_handler(self) -> EventAuthHandler: + return EventAuthHandler(self) + + @cache_in_self def get_external_cache(self) -> ExternalCache: return ExternalCache(self) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index c7ee731154..b3bd92d37c 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -19,6 +19,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, DefaultDict, Dict, FrozenSet, @@ -46,7 +47,7 @@ from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import Collection, StateMap +from synapse.types import StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 32671ddbde..008644cd98 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -18,6 +18,7 @@ import logging from typing import ( Any, Callable, + Collection, Dict, Generator, Iterable, @@ -37,7 +38,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase -from synapse.types import Collection, MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateMap from synapse.util import Clock logger = logging.getLogger(__name__) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 56dd3a4861..6b68d8720c 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -16,13 +16,13 @@ import logging import random from abc import ABCMeta -from typing import TYPE_CHECKING, Any, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool from synapse.storage.types import Connection -from synapse.types import Collection, StreamToken, get_domain_from_id +from synapse.types import StreamToken, get_domain_from_id from synapse.util import json_decoder if TYPE_CHECKING: @@ -114,7 +114,7 @@ def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: db_content = db_content.tobytes() # Decode it to a Unicode string before feeding it to the JSON decoder, since - # Python 3.5 does not support deserializing bytes. + # it only supports handling strings if isinstance(db_content, (bytes, bytearray)): db_content = db_content.decode("utf8") diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 9a6d2b21f9..a761ad603b 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -20,6 +20,7 @@ from time import monotonic as monotonic_time from typing import ( Any, Callable, + Collection, Dict, Iterable, Iterator, @@ -48,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor -from synapse.types import Collection # python 3 does not have a maximum int value MAX_TXN_ID = 2 ** 63 - 1 @@ -171,10 +171,7 @@ class LoggingDatabaseConnection: # The type of entry which goes on our after_callbacks and exception_callbacks lists. -# -# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so -# that mypy sees the type but the runtime python doesn't. -_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] +_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]] R = TypeVar("R") @@ -221,7 +218,7 @@ class LoggingTransaction: self.after_callbacks = after_callbacks self.exception_callbacks = exception_callbacks - def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any): + def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. @@ -233,7 +230,7 @@ class LoggingTransaction: self.after_callbacks.append((callback, args, kwargs)) def call_on_exception( - self, callback: "Callable[..., None]", *args: Any, **kwargs: Any + self, callback: Callable[..., None], *args: Any, **kwargs: Any ): # if self.exception_callbacks is None, that means that whatever constructed the # LoggingTransaction isn't expecting there to be any callbacks; assert that @@ -485,7 +482,7 @@ class DatabasePool: desc: str, after_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry], - func: "Callable[..., R]", + func: Callable[..., R], *args: Any, **kwargs: Any, ) -> R: @@ -618,7 +615,7 @@ class DatabasePool: async def runInteraction( self, desc: str, - func: "Callable[..., R]", + func: Callable[..., R], *args: Any, db_autocommit: bool = False, **kwargs: Any, @@ -678,7 +675,7 @@ class DatabasePool: async def runWithConnection( self, - func: "Callable[..., R]", + func: Callable[..., R], *args: Any, db_autocommit: bool = False, **kwargs: Any, @@ -718,7 +715,9 @@ class DatabasePool: # pool). assert not self.engine.in_transaction(conn) - with LoggingContext("runWithConnection", parent_context) as context: + with LoggingContext( + str(curr_context), parent_context=parent_context + ) as context: sched_duration_sec = monotonic_time() - start_time sql_scheduling_timer.observe(sched_duration_sec) context.add_database_scheduled(sched_duration_sec) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 5c50f5f950..49c7606d51 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -17,7 +17,6 @@ import logging from typing import List, Optional, Tuple -from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import DatabasePool from synapse.storage.databases.main.stats import UserSortOrder @@ -51,7 +50,7 @@ from .media_repository import MediaRepositoryStore from .metrics import ServerMetricsStore from .monthly_active_users import MonthlyActiveUsersStore from .openid import OpenIdStore -from .presence import PresenceStore, UserPresenceState +from .presence import PresenceStore from .profile import ProfileStore from .purge_events import PurgeEventsStore from .push_rule import PushRuleStore @@ -126,9 +125,6 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._presence_id_gen = StreamIdGenerator( - db_conn, "presence_stream", "stream_id" - ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) @@ -177,21 +173,6 @@ class DataStore( super().__init__(database, db_conn, hs) - self._presence_on_startup = self._get_active_presence(db_conn) - - presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( - db_conn, - "presence_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._presence_id_gen.get_current_token(), - ) - self.presence_stream_cache = StreamChangeCache( - "PresenceStreamChangeCache", - min_presence_val, - prefilled_cache=presence_cache_prefill, - ) - device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max @@ -238,32 +219,6 @@ class DataStore( def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() - def take_presence_startup_info(self): - active_on_startup = self._presence_on_startup - self._presence_on_startup = None - return active_on_startup - - def _get_active_presence(self, db_conn): - """Fetch non-offline presence from the database so that we can register - the appropriate time outs. - """ - - sql = ( - "SELECT user_id, state, last_active_ts, last_federation_update_ts," - " last_user_sync_ts, status_msg, currently_active FROM presence_stream" - " WHERE state != ?" - ) - - txn = db_conn.cursor() - txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.db_pool.cursor_to_dict(txn) - txn.close() - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return [UserPresenceState(**row) for row in rows] - async def get_users(self) -> List[JsonDict]: """Function to retrieve a list of users in users table. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index b204875580..c9346de316 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,7 @@ # limitations under the License. import abc import logging -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -31,7 +31,7 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key +from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache @@ -717,7 +717,15 @@ class DeviceWorkerStore(SQLBaseStore): keyvalues={"user_id": user_id}, values={}, insertion_values={"added_ts": self._clock.time_msec()}, - desc="make_remote_user_device_cache_as_stale", + desc="mark_remote_user_device_cache_as_stale", + ) + + async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None: + # Remove the database entry that says we need to resync devices, after a resync + await self.db_pool.simple_delete( + table="device_lists_remote_resync", + keyvalues={"user_id": user_id}, + desc="mark_remote_user_device_cache_as_valid", ) async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: @@ -1289,15 +1297,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): lock=False, ) - # If we're replacing the remote user's device list cache presumably - # we've done a full resync, so we remove the entry that says we need - # to resync - self.db_pool.simple_delete_txn( - txn, - table="device_lists_remote_resync", - keyvalues={"user_id": user_id}, - ) - async def add_device_change_to_streams( self, user_id: str, device_ids: Collection[str], hosts: List[str] ): diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 32ce70a396..ff81d5cd17 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,7 +14,7 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Dict, Iterable, List, Set, Tuple +from typing import Collection, Dict, Iterable, List, Set, Tuple from synapse.api.errors import StoreError from synapse.events import EventBase @@ -25,7 +25,6 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor -from synapse.types import Collection from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index bed4326d11..fd25c8112d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -170,7 +170,7 @@ class PersistEventsStore: ) async with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(events_and_contexts, stream_orderings): + for (event, _), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream await self.db_pool.runInteraction( @@ -297,7 +297,7 @@ class PersistEventsStore: txn.execute(sql + clause, args) to_recursively_check = [] - for event_id, prev_event_id, metadata, rejected in txn: + for _, prev_event_id, metadata, rejected in txn: if prev_event_id in existing_prevs: continue @@ -1127,7 +1127,7 @@ class PersistEventsStore: def _update_forward_extremities_txn( self, txn, new_forward_extremities, max_stream_order ): - for room_id, new_extrem in new_forward_extremities.items(): + for room_id in new_forward_extremities.keys(): self.db_pool.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) @@ -1378,24 +1378,28 @@ class PersistEventsStore: ], ) - for event, _ in events_and_contexts: - if not event.internal_metadata.is_redacted(): - # If we're persisting an unredacted event we go and ensure - # that we mark any redactions that reference this event as - # requiring censoring. - self.db_pool.simple_update_txn( - txn, - table="redactions", - keyvalues={"redacts": event.event_id}, - updatevalues={"have_censored": False}, + # If we're persisting an unredacted event we go and ensure + # that we mark any redactions that reference this event as + # requiring censoring. + sql = "UPDATE redactions SET have_censored = ? WHERE redacts = ?" + txn.execute_batch( + sql, + ( + ( + False, + event.event_id, ) + for event, _ in events_and_contexts + if not event.internal_metadata.is_redacted() + ), + ) state_events_and_contexts = [ ec for ec in events_and_contexts if ec[0].is_state() ] state_values = [] - for event, context in state_events_and_contexts: + for event, _ in state_events_and_contexts: vals = { "event_id": event.event_id, "room_id": event.room_id, @@ -1464,7 +1468,7 @@ class PersistEventsStore: # nothing to do here return - for event, context in events_and_contexts: + for event, _ in events_and_contexts: if event.type == EventTypes.Redaction and event.redacts is not None: # Remove the entries in the event_push_actions table for the # redacted event. @@ -1881,20 +1885,28 @@ class PersistEventsStore: ), ) - for event, _ in events_and_contexts: - user_ids = self.db_pool.simple_select_onecol_txn( - txn, - table="event_push_actions_staging", - keyvalues={"event_id": event.event_id}, - retcol="user_id", - ) + room_to_event_ids = {} # type: Dict[str, List[str]] + for e, _ in events_and_contexts: + room_to_event_ids.setdefault(e.room_id, []).append(e.event_id) - for uid in user_ids: - txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid), + for room_id, event_ids in room_to_event_ids.items(): + rows = self.db_pool.simple_select_many_txn( + txn, + table="event_push_actions_staging", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("user_id",), ) + user_ids = {row["user_id"] for row in rows} + + for user_id in user_ids: + txn.call_after( + self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id, user_id), + ) + # Now we delete the staging area for *all* events that were being # persisted. txn.execute_batch( diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 64d70785b8..2c823e09cf 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -15,7 +15,16 @@ import logging import threading from collections import namedtuple -from typing import Container, Dict, Iterable, List, Optional, Tuple, overload +from typing import ( + Collection, + Container, + Dict, + Iterable, + List, + Optional, + Tuple, + overload, +) from constantly import NamedConstant, Names from typing_extensions import Literal @@ -45,7 +54,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import Collection, JsonDict, get_domain_from_id +from synapse.types import JsonDict, get_domain_from_id from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index b99399ac21..ed20494d4e 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,20 +12,73 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple -from synapse.api.presence import UserPresenceState +from synapse.api.presence import PresenceState, UserPresenceState +from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.storage.types import Connection +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.constants import HOUR_IN_MS from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.server import HomeServer + # How long to keep entries in the users_to_send_full_presence_to db table USERS_TO_SEND_FULL_PRESENCE_TO_ENTRY_LIFETIME_MS = 14 * 24 * HOUR_IN_MS class PresenceStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: Connection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._can_persist_presence = ( + hs.get_instance_name() in hs.config.worker.writers.presence + ) + + if isinstance(database.engine, PostgresEngine): + self._presence_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="presence_stream", + instance_name=self._instance_name, + tables=[("presence_stream", "instance_name", "stream_id")], + sequence_name="presence_stream_sequence", + writers=hs.config.worker.writers.to_device, + ) + else: + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) + + self._presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( + db_conn, + "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_current_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", + min_presence_val, + prefilled_cache=presence_cache_prefill, + ) + async def update_presence(self, presence_states): + assert self._can_persist_presence + stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) @@ -61,6 +114,7 @@ class PresenceStore(SQLBaseStore): "last_user_sync_ts": state.last_user_sync_ts, "status_msg": state.status_msg, "currently_active": state.currently_active, + "instance_name": self._instance_name, } for stream_id, state in zip(stream_orderings, presence_states) ], @@ -293,3 +347,37 @@ class PresenceStore(SQLBaseStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() + + def _get_active_presence(self, db_conn: Connection): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.db_pool.cursor_to_dict(txn) + txn.close() + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + + def take_presence_startup_info(self): + active_on_startup = self._presence_on_startup + self._presence_on_startup = None + return active_on_startup + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) + for row in rows: + self.presence_stream_cache.entity_has_changed(row.user_id, token) + self._get_presence_for_user.invalidate((row.user_id,)) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 833214b7e0..6e5ee557d2 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -91,13 +91,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): id_column=None, ) - self._account_validity = hs.config.account_validity - if hs.config.run_background_tasks and self._account_validity.enabled: - self._clock.call_later( - 0.0, - self._set_expiration_date_when_missing, + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) + self._account_validity_period = None + self._account_validity_startup_job_max_delta = None + if self._account_validity_enabled: + self._account_validity_period = ( + hs.config.account_validity.account_validity_period + ) + self._account_validity_startup_job_max_delta = ( + hs.config.account_validity.account_validity_startup_job_max_delta ) + if hs.config.run_background_tasks: + self._clock.call_later( + 0.0, + self._set_expiration_date_when_missing, + ) + # Create a background job for culling expired 3PID validity tokens if hs.config.run_background_tasks: self._clock.looping_call( @@ -194,6 +206,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts: int, email_sent: bool, renewal_token: Optional[str] = None, + token_used_ts: Optional[int] = None, ) -> None: """Updates the account validity properties of the given account, with the given values. @@ -207,6 +220,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): period. renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. + token_used_ts: A timestamp of when the current token was used to renew + the account. """ def set_account_validity_for_user_txn(txn): @@ -218,6 +233,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "expiration_ts_ms": expiration_ts, "email_sent": email_sent, "renewal_token": renewal_token, + "token_used_ts_ms": token_used_ts, }, ) self._invalidate_cache_and_stream( @@ -231,7 +247,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): async def set_renewal_token_for_user( self, user_id: str, renewal_token: str ) -> None: - """Defines a renewal token for a given user. + """Defines a renewal token for a given user, and clears the token_used timestamp. Args: user_id: ID of the user to set the renewal token for. @@ -244,26 +260,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, - updatevalues={"renewal_token": renewal_token}, + updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None}, desc="set_renewal_token_for_user", ) - async def get_user_from_renewal_token(self, renewal_token: str) -> str: - """Get a user ID from a renewal token. + async def get_user_from_renewal_token( + self, renewal_token: str + ) -> Tuple[str, int, Optional[int]]: + """Get a user ID and renewal status from a renewal token. Args: renewal_token: The renewal token to perform the lookup with. Returns: - The ID of the user to which the token belongs. + A tuple of containing the following values: + * The ID of a user to which the token belongs. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. + * An optional int representing the timestamp of when the user renewed their + account timestamp as milliseconds since the epoch. None if the account + has not been renewed using the current token yet. """ - return await self.db_pool.simple_select_one_onecol( + ret_dict = await self.db_pool.simple_select_one( table="account_validity", keyvalues={"renewal_token": renewal_token}, - retcol="user_id", + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], desc="get_user_from_renewal_token", ) + return ( + ret_dict["user_id"], + ret_dict["expiration_ts_ms"], + ret_dict["token_used_ts_ms"], + ) + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. @@ -302,7 +332,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "get_users_expiring_soon", select_users_txn, self._clock.time_msec(), - self.config.account_validity.renew_at, + self.config.account_validity_renew_at, ) async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: @@ -964,11 +994,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): delta equal to 10% of the validity period. """ now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period + expiration_ts = now_ms + self._account_validity_period if use_delta: expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts - self._account_validity_startup_job_max_delta, expiration_ts, ) @@ -1412,7 +1442,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): except self.database_engine.module.IntegrityError: raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) - if self._account_validity.enabled: + if self._account_validity_enabled: self.set_expiration_date_for_user_txn(txn, user_id) if create_profile_with_displayname: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ef5587f87a..2a8532f8c1 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) + +import attr from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -33,7 +46,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import Collection, PersistedEventPosition, get_domain_from_id +from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -53,6 +66,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + # Used by `_get_joined_hosts` to ensure only one thing mutates the cache + # at a time. Keyed by room_id. + self._joined_host_linearizer = Linearizer("_JoinedHostsCache") + # Is the current_state_events.membership up to date? Or is the # background update still running? self._current_state_events_membership_up_to_date = False @@ -173,6 +190,33 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (room_id, Membership.JOIN)) return [r[0] for r in txn] + @cached(max_entries=100000, iterable=True) + async def get_users_in_room_with_profiles( + self, room_id: str + ) -> Dict[str, ProfileInfo]: + """Get a mapping from user ID to profile information for all users in a given room. + + Args: + room_id: The ID of the room to retrieve the users of. + + Returns: + A mapping from user ID to ProfileInfo. + """ + + def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]: + sql = """ + SELECT user_id, display_name, avatar_url FROM room_memberships + WHERE room_id = ? AND membership = ? + """ + txn.execute(sql, (room_id, Membership.JOIN)) + + return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn} + + return await self.db_pool.runInteraction( + "get_users_in_room_with_profiles", + _get_users_in_room_with_profiles, + ) + @cached(max_entries=100000) async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: """Get the details of a room roughly suitable for use by the room @@ -703,19 +747,82 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, room_id, state_group, current_state_ids, state_entry - ): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. + self, + room_id: str, + state_group: int, + current_state_ids: StateMap[str], + state_entry: "_StateCacheEntry", + ) -> FrozenSet[str]: + # We don't use `state_group`, its there so that we can cache based on + # it. However, its important that its never None, since two + # current_state's with a state_group of None are likely to be different. + # + # The `state_group` must match the `state_entry.state_group` (if not None). assert state_group is not None - + assert state_entry.state_group is None or state_entry.state_group == state_group + + # We use a secondary cache of previous work to allow us to build up the + # joined hosts for the given state group based on previous state groups. + # + # We cache one object per room containing the results of the last state + # group we got joined hosts for. The idea is that generally + # `get_joined_hosts` is called with the "current" state group for the + # room, and so consecutive calls will be for consecutive state groups + # which point to the previous state group. cache = await self._get_joined_hosts_cache(room_id) - return await cache.get_destinations(state_entry) + + # If the state group in the cache matches, we already have the data we need. + if state_entry.state_group == cache.state_group: + return frozenset(cache.hosts_to_joined_users) + + # Since we'll mutate the cache we need to lock. + with (await self._joined_host_linearizer.queue(room_id)): + if state_entry.state_group == cache.state_group: + # Same state group, so nothing to do. We've already checked for + # this above, but the cache may have changed while waiting on + # the lock. + pass + elif state_entry.prev_group == cache.state_group: + # The cached work is for the previous state group, so we work out + # the delta. + for (typ, state_key), event_id in state_entry.delta_ids.items(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = cache.hosts_to_joined_users.setdefault(host, set()) + + event = await self.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + cache.hosts_to_joined_users.pop(host, None) + else: + # The cache doesn't match the state group or prev state group, + # so we calculate the result from first principles. + joined_users = await self.get_joined_users_from_state( + room_id, state_entry + ) + + cache.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + cache.state_group = state_entry.state_group + else: + cache.state_group = object() + + return frozenset(cache.hosts_to_joined_users) @cached(max_entries=10000) def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": - return _JoinedHostsCache(self, room_id) + return _JoinedHostsCache() @cached(num_args=2) async def did_forget(self, user_id: str, room_id: str) -> bool: @@ -1025,71 +1132,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): await self.db_pool.runInteraction("forget_membership", f) +@attr.s(slots=True) class _JoinedHostsCache: - """Cache for joined hosts in a room that is optimised to handle updates - via state deltas. - """ - - def __init__(self, store, room_id): - self.store = store - self.room_id = room_id - - self.hosts_to_joined_users = {} - - self.state_group = object() + """The cached data used by the `_get_joined_hosts_cache`.""" - self.linearizer = Linearizer("_JoinedHostsCache") + # Dict of host to the set of their users in the room at the state group. + hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict) - self._len = 0 - - async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]: - """Get set of destinations for a state entry - - Args: - state_entry - - Returns: - The destinations as a set. - """ - if state_entry.state_group == self.state_group: - return frozenset(self.hosts_to_joined_users) - - with (await self.linearizer.queue(())): - if state_entry.state_group == self.state_group: - pass - elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in state_entry.delta_ids.items(): - if typ != EventTypes.Member: - continue - - host = intern_string(get_domain_from_id(state_key)) - user_id = state_key - known_joins = self.hosts_to_joined_users.setdefault(host, set()) - - event = await self.store.get_event(event_id) - if event.membership == Membership.JOIN: - known_joins.add(user_id) - else: - known_joins.discard(user_id) - - if not known_joins: - self.hosts_to_joined_users.pop(host, None) - else: - joined_users = await self.store.get_joined_users_from_state( - self.room_id, state_entry - ) - - self.hosts_to_joined_users = {} - for user_id in joined_users: - host = intern_string(get_domain_from_id(user_id)) - self.hosts_to_joined_users.setdefault(host, set()).add(user_id) - - if state_entry.state_group: - self.state_group = state_entry.state_group - else: - self.state_group = object() - self._len = sum(len(v) for v in self.hosts_to_joined_users.values()) - return frozenset(self.hosts_to_joined_users) + # The state group `hosts_to_joined_users` is derived from. Will be an object + # if the instance is newly created or if the state is not based on a state + # group. (An object is used as a sentinel value to ensure that it never is + # equal to anything else). + state_group = attr.ib(type=Union[object, int], factory=object) def __len__(self): - return self._len + return sum(len(v) for v in self.hosts_to_joined_users.values()) diff --git a/synapse/storage/databases/main/schema/delta/59/12account_validity_token_used_ts_ms.sql b/synapse/storage/databases/main/schema/delta/59/12account_validity_token_used_ts_ms.sql new file mode 100644 index 0000000000..4836dac16e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/12account_validity_token_used_ts_ms.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track when users renew their account using the value of the 'renewal_token' column. +-- This field should be set to NULL after a fresh token is generated. +ALTER TABLE account_validity ADD token_used_ts_ms BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql new file mode 100644 index 0000000000..b6ba0bda1a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to specify which instance wrote the row. Historic rows have +-- `NULL`, which indicates that the master instance wrote them. +ALTER TABLE presence_stream ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres new file mode 100644 index 0000000000..02b182adf9 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS presence_stream_sequence; + +SELECT setval('presence_stream_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM presence_stream +)); diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 0276f30656..6480d5a9f5 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -15,7 +15,7 @@ import logging import re from collections import namedtuple -from typing import List, Optional, Set +from typing import Collection, List, Optional, Set from synapse.api.errors import SynapseError from synapse.events import EventBase @@ -23,7 +23,6 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import Collection logger = logging.getLogger(__name__) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index db5ce4ea01..7581c7d3ff 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -37,7 +37,7 @@ what sort order was used: import abc import logging from collections import namedtuple -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple from twisted.internet import defer @@ -53,7 +53,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.types import Collection, PersistedEventPosition, RoomStreamToken +from synapse.types import PersistedEventPosition, RoomStreamToken from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index b28ca61f80..82335e7a9d 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,7 +14,7 @@ import logging from collections import namedtuple -from typing import Dict, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -295,33 +295,37 @@ class TransactionStore(TransactionWorkerStore): }, ) - async def bulk_store_destination_rooms_entries( - self, room_and_destination_to_ordering: Dict[Tuple[str, str], int] - ): + async def store_destination_rooms_entries( + self, + destinations: Iterable[str], + room_id: str, + stream_ordering: int, + ) -> None: """ - Updates or creates `destination_rooms` entries for a number of events. + Updates or creates `destination_rooms` entries in batch for a single event. Args: - room_and_destination_to_ordering: A mapping of (room, destination) -> stream_id + destinations: list of destinations + room_id: the room_id of the event + stream_ordering: the stream_ordering of the event """ await self.db_pool.simple_upsert_many( table="destinations", key_names=("destination",), - key_values={(d,) for _, d in room_and_destination_to_ordering.keys()}, + key_values=[(d,) for d in destinations], value_names=[], value_values=[], desc="store_destination_rooms_entries_dests", ) + rows = [(destination, room_id) for destination in destinations] await self.db_pool.simple_upsert_many( table="destination_rooms", - key_names=("room_id", "destination"), - key_values=list(room_and_destination_to_ordering.keys()), + key_names=("destination", "room_id"), + key_values=rows, value_names=["stream_ordering"], - value_values=[ - (stream_id,) for stream_id in room_and_destination_to_ordering.values() - ], + value_values=[(stream_ordering,)] * len(rows), desc="store_destination_rooms_entries_rooms", ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 87e040b014..33dc752d8f 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -17,7 +17,7 @@ import itertools import logging from collections import deque, namedtuple -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter, Histogram @@ -32,7 +32,6 @@ from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( - Collection, PersistedEventPosition, RoomStreamToken, StateMap, diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 05a9355974..7a2cbee426 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -17,7 +17,7 @@ import logging import os import re from collections import Counter -from typing import Generator, Iterable, List, Optional, TextIO, Tuple +from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple import attr from typing_extensions import Counter as CounterType @@ -27,7 +27,6 @@ from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines.postgres import PostgresEngine from synapse.storage.types import Cursor -from synapse.types import Collection logger = logging.getLogger(__name__) diff --git a/synapse/types.py b/synapse/types.py index 21654ae686..e52cd7ffd4 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -15,13 +15,11 @@ import abc import re import string -import sys from collections import namedtuple from typing import ( TYPE_CHECKING, Any, Dict, - Iterable, Mapping, MutableMapping, Optional, @@ -50,18 +48,6 @@ if TYPE_CHECKING: from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore -# define a version of typing.Collection that works on python 3.5 -if sys.version_info[:3] >= (3, 6, 0): - from typing import Collection -else: - from typing import Container, Sized - - T_co = TypeVar("T_co", covariant=True) - - class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore - __slots__ = () - - # Define a state map type from type/state_key to T (usually an event ID or # event) T = TypeVar("T") @@ -213,9 +199,8 @@ def get_localpart_from_id(string): DS = TypeVar("DS", bound="DomainSpecificString") -class DomainSpecificString( - namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta -): +@attr.s(slots=True, frozen=True, repr=False) +class DomainSpecificString(metaclass=abc.ABCMeta): """Common base class among ID/name strings that have a local part and a domain name, prefixed with a sigil. @@ -227,11 +212,8 @@ class DomainSpecificString( SIGIL = abc.abstractproperty() # type: str # type: ignore - # Deny iteration because it will bite you if you try to create a singleton - # set by: - # users = set(user) - def __iter__(self): - raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) + localpart = attr.ib(type=str) + domain = attr.ib(type=str) # Because this class is a namedtuple of strings and booleans, it is deeply # immutable. @@ -286,30 +268,35 @@ class DomainSpecificString( __repr__ = to_string +@attr.s(slots=True, frozen=True, repr=False) class UserID(DomainSpecificString): """Structure representing a user ID.""" SIGIL = "@" +@attr.s(slots=True, frozen=True, repr=False) class RoomAlias(DomainSpecificString): """Structure representing a room name.""" SIGIL = "#" +@attr.s(slots=True, frozen=True, repr=False) class RoomID(DomainSpecificString): """Structure representing a room id. """ SIGIL = "!" +@attr.s(slots=True, frozen=True, repr=False) class EventID(DomainSpecificString): """Structure representing an event id. """ SIGIL = "$" +@attr.s(slots=True, frozen=True, repr=False) class GroupID(DomainSpecificString): """Structure representing a group ID.""" diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index a21d34fcb4..10b0ec6b75 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -17,8 +17,10 @@ from functools import wraps from typing import ( Any, Callable, + Collection, Generic, Iterable, + List, Optional, Type, TypeVar, @@ -57,13 +59,56 @@ class _Node: __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] def __init__( - self, prev_node, next_node, key, value, callbacks: Optional[set] = None + self, + prev_node, + next_node, + key, + value, + callbacks: Collection[Callable[[], None]] = (), ): self.prev_node = prev_node self.next_node = next_node self.key = key self.value = value - self.callbacks = callbacks or set() + + # Set of callbacks to run when the node gets deleted. We store as a list + # rather than a set to keep memory usage down (and since we expect few + # entries per node, the performance of checking for duplication in a + # list vs using a set is negligible). + # + # Note that we store this as an optional list to keep the memory + # footprint down. Storing `None` is free as its a singleton, while empty + # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive + # thing and used sets). + self.callbacks = None # type: Optional[List[Callable[[], None]]] + + self.add_callbacks(callbacks) + + def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None: + """Add to stored list of callbacks, removing duplicates.""" + + if not callbacks: + return + + if not self.callbacks: + self.callbacks = [] + + for callback in callbacks: + if callback not in self.callbacks: + self.callbacks.append(callback) + + def run_and_clear_callbacks(self) -> None: + """Run all callbacks and clear the stored list of callbacks. Used when + the node is being deleted. + """ + + if not self.callbacks: + return + + for callback in self.callbacks: + callback() + + self.callbacks = None class LruCache(Generic[KT, VT]): @@ -177,10 +222,10 @@ class LruCache(Generic[KT, VT]): self.len = synchronized(cache_len) - def add_node(key, value, callbacks: Optional[set] = None): + def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()): prev_node = list_root next_node = prev_node.next_node - node = _Node(prev_node, next_node, key, value, callbacks or set()) + node = _Node(prev_node, next_node, key, value, callbacks) prev_node.next_node = node next_node.prev_node = node cache[key] = node @@ -211,16 +256,15 @@ class LruCache(Generic[KT, VT]): deleted_len = size_callback(node.value) cached_cache_len[0] -= deleted_len - for cb in node.callbacks: - cb() - node.callbacks.clear() + node.run_and_clear_callbacks() + return deleted_len @overload def cache_get( key: KT, default: Literal[None] = None, - callbacks: Iterable[Callable[[], None]] = ..., + callbacks: Collection[Callable[[], None]] = ..., update_metrics: bool = ..., ) -> Optional[VT]: ... @@ -229,7 +273,7 @@ class LruCache(Generic[KT, VT]): def cache_get( key: KT, default: T, - callbacks: Iterable[Callable[[], None]] = ..., + callbacks: Collection[Callable[[], None]] = ..., update_metrics: bool = ..., ) -> Union[T, VT]: ... @@ -238,13 +282,13 @@ class LruCache(Generic[KT, VT]): def cache_get( key: KT, default: Optional[T] = None, - callbacks: Iterable[Callable[[], None]] = (), + callbacks: Collection[Callable[[], None]] = (), update_metrics: bool = True, ): node = cache.get(key, None) if node is not None: move_node_to_front(node) - node.callbacks.update(callbacks) + node.add_callbacks(callbacks) if update_metrics and metrics: metrics.inc_hits() return node.value @@ -260,10 +304,8 @@ class LruCache(Generic[KT, VT]): # We sometimes store large objects, e.g. dicts, which cause # the inequality check to take a long time. So let's only do # the check if we have some callbacks to call. - if node.callbacks and value != node.value: - for cb in node.callbacks: - cb() - node.callbacks.clear() + if value != node.value: + node.run_and_clear_callbacks() # We don't bother to protect this by value != node.value as # generally size_callback will be cheap compared with equality @@ -273,7 +315,7 @@ class LruCache(Generic[KT, VT]): cached_cache_len[0] -= size_callback(node.value) cached_cache_len[0] += size_callback(value) - node.callbacks.update(callbacks) + node.add_callbacks(callbacks) move_node_to_front(node) node.value = value @@ -326,8 +368,7 @@ class LruCache(Generic[KT, VT]): list_root.next_node = list_root list_root.prev_node = list_root for node in cache.values(): - for cb in node.callbacks: - cb() + node.run_and_clear_callbacks() cache.clear() if size_callback: cached_cache_len[0] = 0 diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 2529845c9e..25ea1bcc91 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -110,7 +110,7 @@ class ResponseCache(Generic[T]): return result.observe() def wrap( - self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any + self, key: T, callback: Callable[..., Any], *args: Any, **kwargs: Any ) -> defer.Deferred: """Wrap together a *get* and *set* call, taking care of logcontexts diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 0469e7d120..e81e468899 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,11 +14,10 @@ import logging import math -from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union +from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union from sortedcontainers import SortedDict -from synapse.types import Collection from synapse.util import caches logger = logging.getLogger(__name__) diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 6f73b1d56d..abfdc29832 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -15,6 +15,7 @@ import heapq from itertools import islice from typing import ( + Collection, Dict, Generator, Iterable, @@ -26,8 +27,6 @@ from typing import ( TypeVar, ) -from synapse.types import Collection - T = TypeVar("T") diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 6ed7179e8c..6d14351bd2 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -104,7 +104,13 @@ class Measure: "start", ] - def __init__(self, clock, name): + def __init__(self, clock, name: str): + """ + Args: + clock: A n object with a "time()" method, which returns the current + time in seconds. + name: The name of the metric to report. + """ self.clock = clock self.name = name curr_context = current_context() @@ -117,10 +123,8 @@ class Measure: else: assert isinstance(curr_context, LoggingContext) parent_context = curr_context - self._logging_context = LoggingContext( - "Measure[%s]" % (self.name,), parent_context - ) - self.start = None + self._logging_context = LoggingContext(str(curr_context), parent_context) + self.start = None # type: Optional[int] def __enter__(self) -> "Measure": if self.start is not None: diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index c0e6fb9a60..cd82777f80 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -132,6 +132,38 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int] return host, port +def valid_id_server_location(id_server: str) -> bool: + """Check whether an identity server location, such as the one passed as the + `id_server` parameter to `/_matrix/client/r0/account/3pid/bind`, is valid. + + A valid identity server location consists of a valid hostname and optional + port number, optionally followed by any number of `/` delimited path + components, without any fragment or query string parts. + + Args: + id_server: identity server location string to validate + + Returns: + True if valid, False otherwise. + """ + + components = id_server.split("/", 1) + + host = components[0] + + try: + parse_and_validate_server_name(host) + except ValueError: + return False + + if len(components) < 2: + # no path + return True + + path = components[1] + return "#" not in path and "?" not in path + + def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]: """Parse the given string as an MXC URI diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 281c5be4fb..a1cf1960b0 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -18,6 +18,16 @@ import re logger = logging.getLogger(__name__) +# it's unclear what the maximum length of an email address is. RFC3696 (as corrected +# by errata) says: +# the upper limit on address lengths should normally be considered to be 254. +# +# In practice, mail servers appear to be more tolerant and allow 400 characters +# or so. Let's allow 500, which should be plenty for everyone. +# +MAX_EMAIL_ADDRESS_LENGTH = 500 + + def check_3pid_allowed(hs, medium, address): """Checks whether a given format of 3PID is allowed to be used on this HS @@ -70,3 +80,23 @@ def canonicalise_email(address: str) -> str: raise ValueError("Unable to parse email address") return parts[0].casefold() + "@" + parts[1].lower() + + +def validate_email(address: str) -> str: + """Does some basic validation on an email address. + + Returns the canonicalised email, as returned by `canonicalise_email`. + + Raises a ValueError if the email is invalid. + """ + # First we try canonicalising in case that fails + address = canonicalise_email(address) + + # Email addresses have to be at least 3 characters. + if len(address) < 3: + raise ValueError("Unable to parse email address") + + if len(address) > MAX_EMAIL_ADDRESS_LENGTH: + raise ValueError("Unable to parse email address") + + return address diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py deleted file mode 100644 index 3d45da38ab..0000000000 --- a/tests/app/test_frontend_proxy.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.app.generic_worker import GenericWorkerServer - -from tests.server import make_request -from tests.unittest import HomeserverTestCase - - -class FrontendProxyTests(HomeserverTestCase): - def make_homeserver(self, reactor, clock): - - hs = self.setup_test_homeserver( - federation_http_client=None, homeserver_to_use=GenericWorkerServer - ) - - return hs - - def default_config(self): - c = super().default_config() - c["worker_app"] = "synapse.app.frontend_proxy" - - c["worker_listeners"] = [ - { - "type": "http", - "port": 8080, - "bind_addresses": ["0.0.0.0"], - "resources": [{"names": ["client"]}], - } - ] - - return c - - def test_listen_http_with_presence_enabled(self): - """ - When presence is on, the stub servlet will not register. - """ - # Presence is on - self.hs.config.use_presence = True - - # Listen with the config - self.hs._listen_http(self.hs.config.worker.worker_listeners[0]) - - # Grab the resource from the site that was told to listen - self.assertEqual(len(self.reactor.tcpServers), 1) - site = self.reactor.tcpServers[0][1] - - channel = make_request(self.reactor, site, "PUT", "presence/a/status") - - # 400 + unrecognised, because nothing is registered - self.assertEqual(channel.code, 400) - self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - - def test_listen_http_with_presence_disabled(self): - """ - When presence is off, the stub servlet will register. - """ - # Presence is off - self.hs.config.use_presence = False - - # Listen with the config - self.hs._listen_http(self.hs.config.worker.worker_listeners[0]) - - # Grab the resource from the site that was told to listen - self.assertEqual(len(self.reactor.tcpServers), 1) - site = self.reactor.tcpServers[0][1] - - channel = make_request(self.reactor, site, "PUT", "presence/a/status") - - # 401, because the stub servlet still checks authentication - self.assertEqual(channel.code, 401) - self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 0444b26798..b625995d12 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -13,7 +13,7 @@ # limitations under the License. from unittest.mock import Mock -from synapse.handlers.cas_handler import CasResponse +from synapse.handlers.cas import CasResponse from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index c7b0975a19..8796af45ed 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -222,7 +222,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - for i in range(3): + for _ in range(3): event = create_invite() self.get_success( self.handler.on_invite_request( diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 34d2fc1dfb..a25c89bd5b 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -499,7 +499,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("fetch_error") # Handle code exchange failure - from synapse.handlers.oidc_handler import OidcError + from synapse.handlers.oidc import OidcError self.provider._exchange_code = simple_async_mock( raises=OidcError("invalid_request") @@ -583,7 +583,7 @@ class OidcHandlerTestCase(HomeserverTestCase): body=b'{"error": "foo", "error_description": "bar"}', ) ) - from synapse.handlers.oidc_handler import OidcError + from synapse.handlers.oidc import OidcError exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "foo") @@ -1126,7 +1126,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url: str, ui_auth_session_id: str = "", ) -> str: - from synapse.handlers.oidc_handler import OidcSessionData + from synapse.handlers.oidc import OidcSessionData return self.handler._token_generator.generate_oidc_session_token( state=state, @@ -1152,7 +1152,7 @@ async def _make_callback_with_userinfo( userinfo: the OIDC userinfo dict client_redirect_url: the URL to redirect to on success. """ - from synapse.handlers.oidc_handler import OidcSessionData + from synapse.handlers.oidc import OidcSessionData handler = hs.get_oidc_handler() provider = handler._providers["oidc"] diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 2d12e82897..ce330e79cc 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events.builder import EventBuilder +from synapse.federation.sender import FederationSender from synapse.handlers.presence import ( EXTERNAL_PROCESS_EXPIRY, FEDERATION_PING_INTERVAL, @@ -471,6 +472,190 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): self.assertEqual(state.state, PresenceState.OFFLINE) +class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() + self.instance_name = hs.get_instance_name() + + self.queue = self.presence_handler.get_federation_queue() + + def test_send_and_get(self): + state1 = UserPresenceState.default("@user1:test") + state2 = UserPresenceState.default("@user2:test") + state3 = UserPresenceState.default("@user3:test") + + prev_token = self.queue.get_current_token(self.instance_name) + + self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + + now_token = self.queue.get_current_token(self.instance_name) + + rows, upto_token, limited = self.get_success( + self.queue.get_replication_rows("master", prev_token, now_token, 10) + ) + + self.assertEqual(upto_token, now_token) + self.assertFalse(limited) + + expected_rows = [ + (1, ("dest1", "@user1:test")), + (1, ("dest2", "@user1:test")), + (1, ("dest1", "@user2:test")), + (1, ("dest2", "@user2:test")), + (2, ("dest3", "@user3:test")), + ] + + self.assertCountEqual(rows, expected_rows) + + now_token = self.queue.get_current_token(self.instance_name) + rows, upto_token, limited = self.get_success( + self.queue.get_replication_rows("master", upto_token, now_token, 10) + ) + self.assertEqual(upto_token, now_token) + self.assertFalse(limited) + self.assertCountEqual(rows, []) + + def test_send_and_get_split(self): + state1 = UserPresenceState.default("@user1:test") + state2 = UserPresenceState.default("@user2:test") + state3 = UserPresenceState.default("@user3:test") + + prev_token = self.queue.get_current_token(self.instance_name) + + self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + + now_token = self.queue.get_current_token(self.instance_name) + + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + + rows, upto_token, limited = self.get_success( + self.queue.get_replication_rows("master", prev_token, now_token, 10) + ) + + self.assertEqual(upto_token, now_token) + self.assertFalse(limited) + + expected_rows = [ + (1, ("dest1", "@user1:test")), + (1, ("dest2", "@user1:test")), + (1, ("dest1", "@user2:test")), + (1, ("dest2", "@user2:test")), + ] + + self.assertCountEqual(rows, expected_rows) + + now_token = self.queue.get_current_token(self.instance_name) + rows, upto_token, limited = self.get_success( + self.queue.get_replication_rows("master", upto_token, now_token, 10) + ) + + self.assertEqual(upto_token, now_token) + self.assertFalse(limited) + + expected_rows = [ + (2, ("dest3", "@user3:test")), + ] + + self.assertCountEqual(rows, expected_rows) + + def test_clear_queue_all(self): + state1 = UserPresenceState.default("@user1:test") + state2 = UserPresenceState.default("@user2:test") + state3 = UserPresenceState.default("@user3:test") + + prev_token = self.queue.get_current_token(self.instance_name) + + self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + + self.reactor.advance(10 * 60 * 1000) + + now_token = self.queue.get_current_token(self.instance_name) + + rows, upto_token, limited = self.get_success( + self.queue.get_replication_rows("master", prev_token, now_token, 10) + ) + self.assertEqual(upto_token, now_token) + self.assertFalse(limited) + self.assertCountEqual(rows, []) + + prev_token = self.queue.get_current_token(self.instance_name) + + self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + + now_token = self.queue.get_current_token(self.instance_name) + + rows, upto_token, limited = self.get_success( + self.queue.get_replication_rows("master", prev_token, now_token, 10) + ) + self.assertEqual(upto_token, now_token) + self.assertFalse(limited) + + expected_rows = [ + (3, ("dest1", "@user1:test")), + (3, ("dest2", "@user1:test")), + (3, ("dest1", "@user2:test")), + (3, ("dest2", "@user2:test")), + (4, ("dest3", "@user3:test")), + ] + + self.assertCountEqual(rows, expected_rows) + + def test_partially_clear_queue(self): + state1 = UserPresenceState.default("@user1:test") + state2 = UserPresenceState.default("@user2:test") + state3 = UserPresenceState.default("@user3:test") + + prev_token = self.queue.get_current_token(self.instance_name) + + self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + + self.reactor.advance(2 * 60 * 1000) + + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + + self.reactor.advance(4 * 60 * 1000) + + now_token = self.queue.get_current_token(self.instance_name) + + rows, upto_token, limited = self.get_success( + self.queue.get_replication_rows("master", prev_token, now_token, 10) + ) + self.assertEqual(upto_token, now_token) + self.assertFalse(limited) + + expected_rows = [ + (2, ("dest3", "@user3:test")), + ] + self.assertCountEqual(rows, []) + + prev_token = self.queue.get_current_token(self.instance_name) + + self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + + now_token = self.queue.get_current_token(self.instance_name) + + rows, upto_token, limited = self.get_success( + self.queue.get_replication_rows("master", prev_token, now_token, 10) + ) + self.assertEqual(upto_token, now_token) + self.assertFalse(limited) + + expected_rows = [ + (3, ("dest1", "@user1:test")), + (3, ("dest2", "@user1:test")), + (3, ("dest1", "@user2:test")), + (3, ("dest2", "@user2:test")), + (4, ("dest3", "@user3:test")), + ] + + self.assertCountEqual(rows, expected_rows) + + class PresenceJoinTestCase(unittest.HomeserverTestCase): """Tests remote servers get told about presence of users in the room when they join and when new local users join. @@ -482,10 +667,17 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "server", federation_http_client=None, federation_sender=Mock() + "server", + federation_http_client=None, + federation_sender=Mock(spec=FederationSender), ) return hs + def default_config(self): + config = super().default_config() + config["send_federation"] = True + return config + def prepare(self, reactor, clock, hs): self.federation_sender = hs.get_federation_sender() self.event_builder_factory = hs.get_event_builder_factory() @@ -529,9 +721,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Add a new remote server to the room self._add_new_user(room_id, "@alice:server2") - # We shouldn't have sent out any local presence *updates* - self.federation_sender.send_presence.assert_not_called() - # When new server is joined we send it the local users presence states. # We expect to only see user @test2:server, as @test:server is offline # and has a zero last_active_ts @@ -550,7 +739,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.federation_sender.reset_mock() self._add_new_user(room_id, "@bob:server3") - self.federation_sender.send_presence.assert_not_called() self.federation_sender.send_presence_to_destinations.assert_called_once_with( destinations=["server3"], states={expected_state} ) @@ -595,9 +783,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.reactor.pump([0]) # Wait for presence updates to be handled - # We shouldn't have sent out any local presence *updates* - self.federation_sender.send_presence.assert_not_called() - # We expect to only send test2 presence to server2 and server3 expected_state = self.get_success( self.presence_handler.current_state_for_user("@test2:server") diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 9e97185507..ed9a884d76 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -26,6 +26,7 @@ from twisted.web.http import HTTPChannel from synapse.api.errors import RequestSendFailed from synapse.http.matrixfederationclient import ( + MAX_RESPONSE_SIZE, MatrixFederationHttpClient, MatrixFederationRequest, ) @@ -560,3 +561,61 @@ class FederationClientTests(HomeserverTestCase): f = self.failureResultOf(test_d) self.assertIsInstance(f.value, RequestSendFailed) + + def test_too_big(self): + """ + Test what happens if a huge response is returned from the remote endpoint. + """ + + test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) + + self.pump() + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8008) + + # complete the connection and wire it up to a fake transport + protocol = factory.buildProtocol(None) + transport = StringTransport() + protocol.makeConnection(transport) + + # that should have made it send the request to the transport + self.assertRegex(transport.value(), b"^GET /foo/bar") + self.assertRegex(transport.value(), b"Host: testserv:8008") + + # Deferred is still without a result + self.assertNoResult(test_d) + + # Send it a huge HTTP response + protocol.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Server: Fake\r\n" + b"Content-Type: application/json\r\n" + b"\r\n" + ) + + self.pump() + + # should still be waiting + self.assertNoResult(test_d) + + sent = 0 + chunk_size = 1024 * 512 + while not test_d.called: + protocol.dataReceived(b"a" * chunk_size) + sent += chunk_size + self.assertLessEqual(sent, MAX_RESPONSE_SIZE) + + self.assertEqual(sent, MAX_RESPONSE_SIZE) + + f = self.failureResultOf(test_d) + self.assertIsInstance(f.value, RequestSendFailed) + + self.assertTrue(transport.disconnecting) diff --git a/tests/http/test_site.py b/tests/http/test_site.py new file mode 100644 index 0000000000..8c13b4f693 --- /dev/null +++ b/tests/http/test_site.py @@ -0,0 +1,83 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet.address import IPv6Address +from twisted.test.proto_helpers import StringTransport + +from synapse.app.homeserver import SynapseHomeServer + +from tests.unittest import HomeserverTestCase + + +class SynapseRequestTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer) + + def test_large_request(self): + """overlarge HTTP requests should be rejected""" + self.hs.start_listening() + + # find the HTTP server which is configured to listen on port 0 + (port, factory, _backlog, interface) = self.reactor.tcpServers[0] + self.assertEqual(interface, "::") + self.assertEqual(port, 0) + + # as a control case, first send a regular request. + + # complete the connection and wire it up to a fake transport + client_address = IPv6Address("TCP", "::1", "2345") + protocol = factory.buildProtocol(client_address) + transport = StringTransport() + protocol.makeConnection(transport) + + protocol.dataReceived( + b"POST / HTTP/1.1\r\n" + b"Connection: close\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"0\r\n" + b"\r\n" + ) + + while not transport.disconnecting: + self.reactor.advance(1) + + # we should get a 404 + self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ") + + # now send an oversized request + protocol = factory.buildProtocol(client_address) + transport = StringTransport() + protocol.makeConnection(transport) + + protocol.dataReceived( + b"POST / HTTP/1.1\r\n" + b"Connection: close\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + # we deliberately send all the data in one big chunk, to ensure that + # twisted isn't buffering the data in the chunked transfer decoder. + # we start with the chunk size, in hex. (We won't actually send this much) + protocol.dataReceived(b"10000000\r\n") + sent = 0 + while not transport.disconnected: + self.assertLess(sent, 0x10000000, "connection did not drop") + protocol.dataReceived(b"\0" * 1024) + sent += 1024 + + # default max upload size is 50M, so it should drop on the next buffer after + # that. + self.assertEqual(sent, 50 * 1024 * 1024 + 1024) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 6cddb95c30..1160716929 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -137,7 +137,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): ] self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - self.assertTrue(log["request"].startswith("name@")) + self.assertEqual(log["request"], "name") def test_with_request_context(self): """ @@ -164,7 +164,9 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): # Also set the requester to ensure the processing works. request.requester = "@foo:test" - with LoggingContext(parent_context=request.logcontext): + with LoggingContext( + request.get_request_id(), parent_context=request.logcontext + ): logger.info("Hello there, %s!", "wally") log = self.get_log_line() diff --git a/tests/replication/_base.py b/tests/replication/_base.py index c9d04aef29..624bd1b927 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,14 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple -from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol -from twisted.internet.task import LoopingCall -from twisted.web.http import HTTPChannel from twisted.web.resource import Resource -from twisted.web.server import Request, Site from synapse.app.generic_worker import GenericWorkerServer from synapse.http.server import JsonResource @@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import ( ServerReplicationStreamProtocol, ) from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest from tests.server import FakeTransport @@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): client_protocol = client_factory.buildProtocol(None) # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site) + channel = self.site.buildProtocol(None) + + # hook into the channel's request factory so that we can keep a record + # of the requests + requests: List[SynapseRequest] = [] + real_request_factory = channel.requestFactory + + def request_factory(*args, **kwargs): + request = real_request_factory(*args, **kwargs) + requests.append(request) + return request + + channel.requestFactory = request_factory # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server_to_client_transport.loseConnection() client_to_server_transport.loseConnection() - return channel.request + # there should have been exactly one request + self.assertEqual(len(requests), 1) + + return requests[0] def assert_request_is_get_repl_stream_updates( self, request: SynapseRequest, stream_name: str @@ -349,6 +359,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config=worker_hs.config.server.listeners[0], resource=resource, server_version_string="1", + max_request_body_size=4096, + reactor=self.reactor, ) if worker_hs.config.redis.redis_enabled: @@ -386,7 +398,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): client_protocol = client_factory.buildProtocol(None) # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs]) + channel = self._hs_to_site[hs].buildProtocol(None) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -444,112 +456,6 @@ class TestReplicationDataHandler(ReplicationDataHandler): self.received_rdata_rows.append((stream_name, token, r)) -class _PushHTTPChannel(HTTPChannel): - """A HTTPChannel that wraps pull producers to push producers. - - This is a hack to get around the fact that HTTPChannel transparently wraps a - pull producer (which is what Synapse uses to reply to requests) with - `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush` - uses the standard reactor rather than letting us use our test reactor, which - makes it very hard to test. - """ - - def __init__( - self, reactor: IReactorTime, request_factory: Type[Request], site: Site - ): - super().__init__() - self.reactor = reactor - self.requestFactory = request_factory - self.site = site - - self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] - - def registerProducer(self, producer, streaming): - # Convert pull producers to push producer. - if not streaming: - self._pull_to_push_producer = _PullToPushProducer( - self.reactor, producer, self - ) - producer = self._pull_to_push_producer - - super().registerProducer(producer, True) - - def unregisterProducer(self): - if self._pull_to_push_producer: - # We need to manually stop the _PullToPushProducer. - self._pull_to_push_producer.stop() - - def checkPersistence(self, request, version): - """Check whether the connection can be re-used""" - # We hijack this to always say no for ease of wiring stuff up in - # `handle_http_replication_attempt`. - request.responseHeaders.setRawHeaders(b"connection", [b"close"]) - return False - - def requestDone(self, request): - # Store the request for inspection. - self.request = request - super().requestDone(request) - - -class _PullToPushProducer: - """A push producer that wraps a pull producer.""" - - def __init__( - self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer - ): - self._clock = Clock(reactor) - self._producer = producer - self._consumer = consumer - - # While running we use a looping call with a zero delay to call - # resumeProducing on given producer. - self._looping_call = None # type: Optional[LoopingCall] - - # We start writing next reactor tick. - self._start_loop() - - def _start_loop(self): - """Start the looping call to""" - - if not self._looping_call: - # Start a looping call which runs every tick. - self._looping_call = self._clock.looping_call(self._run_once, 0) - - def stop(self): - """Stops calling resumeProducing.""" - if self._looping_call: - self._looping_call.stop() - self._looping_call = None - - def pauseProducing(self): - """Implements IPushProducer""" - self.stop() - - def resumeProducing(self): - """Implements IPushProducer""" - self._start_loop() - - def stopProducing(self): - """Implements IPushProducer""" - self.stop() - self._producer.stopProducing() - - def _run_once(self): - """Calls resumeProducing on producer once.""" - - try: - self._producer.resumeProducing() - except Exception: - logger.exception("Failed to call resumeProducing") - try: - self._consumer.unregisterProducer() - except Exception: - pass - - self.stopProducing() - - class FakeRedisPubSubServer: """A fake Redis server for pub/sub.""" diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 323237c1bb..f51fa0a79e 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -239,7 +239,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # the state rows are unsorted state_rows = [] # type: List[EventsStreamCurrentStateRow] - for stream_name, token, row in received_rows: + for stream_name, _, row in received_rows: self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "state") @@ -356,7 +356,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # the state rows are unsorted state_rows = [] # type: List[EventsStreamCurrentStateRow] - for j in range(STATES_PER_USER + 1): + for _ in range(STATES_PER_USER + 1): stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index ecbee30bb5..120730b764 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -430,7 +430,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ # Create devices number_devices = 5 - for n in range(number_devices): + for _ in range(number_devices): self.login("user", "pass") # Get devices @@ -547,7 +547,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): # Create devices number_devices = 5 - for n in range(number_devices): + for _ in range(number_devices): self.login("user", "pass") # Get devices diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 8c66da3af4..29341bc6e9 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -48,22 +48,22 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok) # Two rooms and two users. Every user sends and reports every room event - for i in range(5): + for _ in range(5): self._create_event_and_report( room_id=self.room_id1, user_tok=self.other_user_tok, ) - for i in range(5): + for _ in range(5): self._create_event_and_report( room_id=self.room_id2, user_tok=self.other_user_tok, ) - for i in range(5): + for _ in range(5): self._create_event_and_report( room_id=self.room_id1, user_tok=self.admin_user_tok, ) - for i in range(5): + for _ in range(5): self._create_event_and_report( room_id=self.room_id2, user_tok=self.admin_user_tok, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 6bcd997085..ee071c2477 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -17,6 +17,8 @@ import urllib.parse from typing import List, Optional from unittest.mock import Mock +from parameterized import parameterized_class + import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes @@ -144,6 +146,13 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): ) +@parameterized_class( + ("method", "url_template"), + [ + ("POST", "/_synapse/admin/v1/rooms/%s/delete"), + ("DELETE", "/_synapse/admin/v1/rooms/%s"), + ], +) class DeleteRoomTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -175,7 +184,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as( self.other_user, tok=self.other_user_tok ) - self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + self.url = self.url_template % self.room_id def test_requester_is_no_admin(self): """ @@ -183,7 +192,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "POST", + self.method, self.url, json.dumps({}), access_token=self.other_user_tok, @@ -196,10 +205,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ Check that unknown rooms/server return error 404. """ - url = "/_synapse/admin/v1/rooms/!unknown:test/delete" + url = self.url_template % "!unknown:test" channel = self.make_request( - "POST", + self.method, url, json.dumps({}), access_token=self.admin_user_tok, @@ -212,10 +221,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ Check that invalid room names, return an error 400. """ - url = "/_synapse/admin/v1/rooms/invalidroom/delete" + url = self.url_template % "invalidroom" channel = self.make_request( - "POST", + self.method, url, json.dumps({}), access_token=self.admin_user_tok, @@ -234,7 +243,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"new_room_user_id": "@unknown:test"}) channel = self.make_request( - "POST", + self.method, self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, @@ -253,7 +262,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"new_room_user_id": "@not:exist.bla"}) channel = self.make_request( - "POST", + self.method, self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, @@ -272,7 +281,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": "NotBool"}) channel = self.make_request( - "POST", + self.method, self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, @@ -288,7 +297,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"purge": "NotBool"}) channel = self.make_request( - "POST", + self.method, self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, @@ -314,7 +323,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": True, "purge": True}) channel = self.make_request( - "POST", + self.method, self.url.encode("ascii"), content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, @@ -347,7 +356,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": False, "purge": True}) channel = self.make_request( - "POST", + self.method, self.url.encode("ascii"), content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, @@ -381,7 +390,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": False, "purge": False}) channel = self.make_request( - "POST", + self.method, self.url.encode("ascii"), content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, @@ -426,10 +435,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._is_member(room_id=self.room_id, user_id=self.other_user) # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id channel = self.make_request( - "POST", - url.encode("ascii"), + self.method, + self.url, json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) @@ -473,10 +481,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._is_member(room_id=self.room_id, user_id=self.other_user) # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id channel = self.make_request( - "POST", - url.encode("ascii"), + self.method, + self.url, json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) @@ -615,7 +622,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Create 3 test rooms total_rooms = 3 room_ids = [] - for x in range(total_rooms): + for _ in range(total_rooms): room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok ) @@ -679,7 +686,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Create 5 test rooms total_rooms = 5 room_ids = [] - for x in range(total_rooms): + for _ in range(total_rooms): room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok ) @@ -1577,7 +1584,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel.json_body["event"]["event_id"], events[midway]["event_id"] ) - for i, found_event in enumerate(channel.json_body["events_before"]): + for found_event in channel.json_body["events_before"]: for j, posted_event in enumerate(events): if found_event["event_id"] == posted_event["event_id"]: self.assertTrue(j < midway) @@ -1585,7 +1592,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): else: self.fail("Event %s from events_before not found" % j) - for i, found_event in enumerate(channel.json_body["events_after"]): + for found_event in channel.json_body["events_after"]: for j, posted_event in enumerate(events): if found_event["event_id"] == posted_event["event_id"]: self.assertTrue(j > midway) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 363bdeeb2d..79cac4266b 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -467,7 +467,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): number_media: Number of media to be created for the user """ upload_resource = self.media_repo.children[b"upload"] - for i in range(number_media): + for _ in range(number_media): # file size is 67 Byte image_data = unhexlify( b"89504e470d0a1a0a0000000d4948445200000001000000010806" diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 2844c493fc..d599a4c984 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -18,7 +18,7 @@ import json import urllib.parse from binascii import unhexlify from typing import List, Optional -from unittest.mock import Mock +from unittest.mock import Mock, patch import synapse.rest.admin from synapse.api.constants import UserTypes @@ -54,8 +54,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.datastore = Mock(return_value=Mock()) self.datastore.get_current_state_deltas = Mock(return_value=(0, [])) - self.secrets = Mock() - self.hs = self.setup_test_homeserver() self.hs.config.registration_shared_secret = "shared" @@ -84,14 +82,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): Calling GET on the endpoint will return a randomised nonce, using the homeserver's secrets provider. """ - secrets = Mock() - secrets.token_hex = Mock(return_value="abcd") - - self.hs.get_secrets = Mock(return_value=secrets) + with patch("secrets.token_hex") as token_hex: + # Patch secrets.token_hex for the duration of this context + token_hex.return_value = "abcd" - channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) - self.assertEqual(channel.json_body, {"nonce": "abcd"}) + self.assertEqual(channel.json_body, {"nonce": "abcd"}) def test_expired_nonce(self): """ @@ -1937,7 +1934,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): # Create rooms and join other_user_tok = self.login("user", "pass") number_rooms = 5 - for n in range(number_rooms): + for _ in range(number_rooms): self.helper.create_room_as(self.other_user, tok=other_user_tok) # Get rooms @@ -2517,7 +2514,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): user_token: Access token of the user number_media: Number of media to be created for the user """ - for i in range(number_media): + for _ in range(number_media): # file size is 67 Byte image_data = unhexlify( b"89504e470d0a1a0a0000000d4948445200000001000000010806" diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 3a050659ca..409f3949dc 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -16,6 +16,7 @@ from unittest.mock import Mock from twisted.internet import defer +from synapse.handlers.presence import PresenceHandler from synapse.rest.client.v1 import presence from synapse.types import UserID @@ -32,7 +33,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - presence_handler = Mock() + presence_handler = Mock(spec=PresenceHandler) presence_handler.set_state.return_value = defer.succeed(None) hs = self.setup_test_homeserver( @@ -59,12 +60,12 @@ class PresenceTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) + @unittest.override_config({"use_presence": False}) def test_put_presence_disabled(self): """ PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. """ - self.hs.config.use_presence = False body = {"presence": "here", "status_msg": "beep boop"} channel = self.make_request( diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 92babf65e0..a3694f3d02 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -646,7 +646,7 @@ class RoomInviteRatelimitTestCase(RoomBase): def test_invites_by_users_ratelimit(self): """Tests that invites to a specific user are actually rate-limited.""" - for i in range(3): + for _ in range(3): room_id = self.helper.create_room_as(self.user_id) self.helper.invite(room_id, self.user_id, "@other-users:red") @@ -668,7 +668,7 @@ class RoomJoinRatelimitTestCase(RoomBase): ) def test_join_local_ratelimit(self): """Tests that local joins are actually rate-limited.""" - for i in range(3): + for _ in range(3): self.helper.create_room_as(self.user_id) self.helper.create_room_as(self.user_id, expect_code=429) @@ -733,7 +733,7 @@ class RoomJoinRatelimitTestCase(RoomBase): for path in paths_to_test: # Make sure we send more requests than the rate-limiting config would allow # if all of these requests ended up joining the user to a room. - for i in range(4): + for _ in range(4): channel = self.make_request("POST", path % room_id, {}) self.assertEquals(channel.code, 200) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 054d4e4140..1cad5f00eb 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -310,6 +310,57 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(channel.json_body.get("sid")) + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_reject_invalid_email(self): + """Check that bad emails are rejected""" + + # Test for email with multiple @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + # Check error to ensure that we're not erroring due to a bug in the test. + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for email with no @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for super long email + email = "a@" + "a" * 1000 + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + class AccountValidityTestCase(unittest.HomeserverTestCase): @@ -492,8 +543,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - # Move 6 days forward. This should trigger a renewal email to be sent. - self.reactor.advance(datetime.timedelta(days=6).total_seconds()) + # Move 5 days forward. This should trigger a renewal email to be sent. + self.reactor.advance(datetime.timedelta(days=5).total_seconds()) self.assertEqual(len(self.email_attempts), 1) # Retrieving the URL from the email is too much pain for now, so we @@ -504,14 +555,32 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. - content_type = None - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type = header[1] - self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) # Check that the HTML we're getting is the one we expect on a successful renewal. - expected_html = self.hs.config.account_validity.account_renewed_html_content + expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 1 day forward. Try to renew with the same token again. + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when reusing a + # token. The account expiration date should not have changed. + expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) self.assertEqual( channel.result["body"], expected_html.encode("utf8"), channel.result ) @@ -531,15 +600,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. - content_type = None - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type = header[1] - self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) # Check that the HTML we're getting is the one we expect when using an # invalid/unknown token. - expected_html = self.hs.config.account_validity.invalid_token_html_content + expected_html = ( + self.hs.config.account_validity.account_validity_invalid_token_template.render() + ) self.assertEqual( channel.result["body"], expected_html.encode("utf8"), channel.result ) @@ -647,7 +715,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): config["account_validity"] = {"enabled": False} self.hs = self.setup_test_homeserver(config=config) - self.hs.config.account_validity.period = self.validity_period + + # We need to set these directly, instead of in the homeserver config dict above. + # This is due to account validity-related config options not being read by + # Synapse when account_validity.enabled is False. + self.hs.get_datastore()._account_validity_period = self.validity_period + self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta self.store = self.hs.get_datastore() diff --git a/tests/server.py b/tests/server.py index b535a5d886..9df8cda24f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -603,12 +603,6 @@ class FakeTransport: if self.disconnected: return - if not hasattr(self.other, "transport"): - # the other has no transport yet; reschedule - if self.autoflush: - self._reactor.callLater(0.0, self.flush) - return - if maxbytes is not None: to_write = self.buffer[:maxbytes] else: diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 6339a43f0c..200b9198f9 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import secrets from tests import unittest @@ -21,7 +22,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.storage = hs.get_datastore() - self.table_name = "table_" + hs.get_secrets().token_hex(6) + self.table_name = "table_" + secrets.token_hex(6) self.get_success( self.storage.db_pool.runInteraction( "create", diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 397e68fe0a..088fbb247b 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -38,12 +38,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase): last_event = None # Make a real event chain - for i in range(event_count): + for _ in range(event_count): ev = self.create_and_send_event(room_id, user, False, last_event) last_event = [ev] # Sprinkle in some extremities - for i in range(extrems): + for _ in range(extrems): ev = self.create_and_send_event(room_id, user, False, last_event) # Let it run for a while, then pull out the statistics from the diff --git a/tests/test_federation.py b/tests/test_federation.py index 0a3a996ec1..0ed8326f55 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -135,7 +135,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - with LoggingContext(): + with LoggingContext("test-context"): failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True diff --git a/tests/test_server.py b/tests/test_server.py index 55cde7f62f..407e172e41 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -202,6 +202,8 @@ class OptionsResourceTests(unittest.TestCase): parse_listener_def({"type": "http", "port": 0}), self.resource, "1.0", + max_request_body_size=1234, + reactor=self.reactor, ) # render the request and return the channel diff --git a/tests/unittest.py b/tests/unittest.py index d890ad981f..74db7c08f1 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -18,6 +18,7 @@ import hashlib import hmac import inspect import logging +import secrets import time from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from unittest.mock import Mock, patch @@ -133,7 +134,7 @@ class TestCase(unittest.TestCase): def assertObjectHasAttributes(self, attrs, obj): """Asserts that the given object has each of the attributes given, and that the value of each matches according to assertEquals.""" - for (key, value) in attrs.items(): + for key in attrs.keys(): if not hasattr(obj, key): raise AssertionError("Expected obj to have a '.%s'" % key) try: @@ -247,6 +248,8 @@ class HomeserverTestCase(TestCase): config=self.hs.config.server.listeners[0], resource=self.resource, server_version_string="1", + max_request_body_size=1234, + reactor=self.reactor, ) from tests.rest.client.v1.utils import RestHelper @@ -624,7 +627,6 @@ class HomeserverTestCase(TestCase): str: The new event's ID. """ event_creator = self.hs.get_event_creation_handler() - secrets = self.hs.get_secrets() requester = create_requester(user) event, context = self.get_success( diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 40cd98e2d8..178ac8a68c 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -230,8 +230,7 @@ class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def do_lookup(): - with LoggingContext() as c1: - c1.name = "c1" + with LoggingContext("c1") as c1: r = yield obj.fn(1) self.assertEqual(current_context(), c1) return r @@ -273,8 +272,7 @@ class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def do_lookup(): - with LoggingContext() as c1: - c1.name = "c1" + with LoggingContext("c1") as c1: try: d = obj.fn(1) self.assertEqual( diff --git a/tests/utils.py b/tests/utils.py index af6b32fc66..6bd008dcfe 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -153,6 +153,10 @@ def default_config(name, parse=False): "local": {"per_second": 10000, "burst_count": 10000}, "remote": {"per_second": 10000, "burst_count": 10000}, }, + "rc_invites": { + "per_room": {"per_second": 10000, "burst_count": 10000}, + "per_user": {"per_second": 10000, "burst_count": 10000}, + }, "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000}, "saml2_enabled": False, "public_baseurl": None, @@ -303,7 +307,7 @@ def setup_test_homeserver( # database for a few more seconds due to flakiness, preventing # us from dropping it when the test is over. If we can't drop # it, warn and move on. - for x in range(5): + for _ in range(5): try: cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) db_conn.commit() diff --git a/tox.ini b/tox.ini index 998b04b224..ecd609271d 100644 --- a/tox.ini +++ b/tox.ini @@ -21,13 +21,11 @@ deps = # installed on that). # # anyway, make sure that we have a recent enough setuptools. - setuptools>=18.5 ; python_version >= '3.6' - setuptools>=18.5,<51.0.0 ; python_version < '3.6' + setuptools>=18.5 # we also need a semi-recent version of pip, because old ones fail to # install the "enum34" dependency of cryptography. - pip>=10 ; python_version >= '3.6' - pip>=10,<21.0 ; python_version < '3.6' + pip>=10 # directories/files we run the linters on. # if you update this list, make sure to do the same in scripts-dev/lint.sh @@ -168,8 +166,7 @@ skip_install = true usedevelop = false deps = coverage - pip>=10 ; python_version >= '3.6' - pip>=10,<21.0 ; python_version < '3.6' + pip>=10 commands= coverage combine coverage report |