author    | Eric Eastwood <erice@element.io> | 2021-12-09 02:53:17 -0600
committer | Eric Eastwood <erice@element.io> | 2021-12-09 02:53:17 -0600
commit    | 6b64184585b4935c75e6f4ed44ac17ca48e1f71f (patch)
tree      | 048c8b90df46d609514b8bba2c6df5bb809225d2
parent    | Merge branch 'develop' into madlittlemods/return-historical-events-in-order-f... (diff)
parent    | Add a constant for receipt types (m.read). (#11531) (diff)
download  | synapse-6b64184585b4935c75e6f4ed44ac17ca48e1f71f.tar.xz
Merge branch 'develop' into madlittlemods/return-historical-events-in-order-from-backfill
Conflicts:
	scripts-dev/complement.sh
	synapse/handlers/room_batch.py
356 files changed, 13482 insertions, 4590 deletions
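As context for the header above: the sketch below shows roughly how a merge commit with these conflicts is produced. It assumes a local clone with both branches available; the commands and conflict output are illustrative, not a record of what was actually run.

```sh
# Illustrative reconstruction (assumes a local clone with both branches fetched):
git checkout madlittlemods/return-historical-events-in-order-from-backfill
git merge develop
# git would report the two files recorded in the commit message, e.g.:
#   CONFLICT (content): Merge conflict in scripts-dev/complement.sh
#   CONFLICT (content): Merge conflict in synapse/handlers/room_batch.py
# After resolving the conflicts by hand:
git add scripts-dev/complement.sh synapse/handlers/room_batch.py
git commit --no-edit  # records the merge commit shown above
```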
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index af7ed21fce..3276d1e122 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -5,7 +5,7 @@ name: Build docker images on: push: tags: ["v*"] - branches: [ master, main ] + branches: [ master, main, develop ] workflow_dispatch: permissions: @@ -38,6 +38,9 @@ jobs: id: set-tag run: | case "${GITHUB_REF}" in + refs/heads/develop) + tag=develop + ;; refs/heads/master|refs/heads/main) tag=latest ;; diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8d7e8cafd9..21c9ee7823 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -374,7 +374,7 @@ jobs: working-directory: complement/dockerfiles # Run Complement - - run: go test -v -tags synapse_blacklist,msc2403,msc2946,msc3083 ./tests/... + - run: go test -v -tags synapse_blacklist,msc2403 ./tests/... env: COMPLEMENT_BASE_IMAGE: complement-synapse:latest working-directory: complement diff --git a/CHANGES.md b/CHANGES.md index e74544f489..72e8d64cf7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,317 @@ +Synapse 1.49.0rc1 (2021-12-07) +============================== + +We've decided to move the existing, somewhat stagnant pages from the GitHub wiki +to the [documentation website](https://matrix-org.github.io/synapse/latest/). + +This was done for two reasons. The first was to ensure that changes are checked by +multiple authors before being committed (everyone makes mistakes!) and the second +was visibility of the documentation. Not everyone knows that Synapse has some very +useful information hidden away in its GitHub wiki pages. Bringing them to the +documentation website should help with visibility, as well as keep all Synapse documentation +in one, easily-searchable location. + +Note that contributions to the documentation website happen through [GitHub pull +requests](https://github.com/matrix-org/synapse/pulls). Please visit [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org) +if you need help with the process! + + +Features +-------- + +- Add [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) experimental client and federation API endpoints to get the closest event to a given timestamp. ([\#9445](https://github.com/matrix-org/synapse/issues/9445)) +- Include bundled relation aggregations during a limited `/sync` request and `/relations` request, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11284](https://github.com/matrix-org/synapse/issues/11284), [\#11478](https://github.com/matrix-org/synapse/issues/11478)) +- Add plugin support for controlling database background updates. ([\#11306](https://github.com/matrix-org/synapse/issues/11306), [\#11475](https://github.com/matrix-org/synapse/issues/11475), [\#11479](https://github.com/matrix-org/synapse/issues/11479)) +- Support the stable API endpoints for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946): the room `/hierarchy` endpoint. ([\#11329](https://github.com/matrix-org/synapse/issues/11329)) +- Add admin API to get some information about federation status with remote servers. ([\#11407](https://github.com/matrix-org/synapse/issues/11407)) +- Support expiry of refresh tokens and expiry of the overall session when refresh tokens are in use. 
([\#11425](https://github.com/matrix-org/synapse/issues/11425)) +- Stabilise support for [MSC2918](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) refresh tokens as they have now been merged into the Matrix specification. ([\#11435](https://github.com/matrix-org/synapse/issues/11435), [\#11522](https://github.com/matrix-org/synapse/issues/11522)) +- Update [MSC2918 refresh token](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) support to conform to the latest revision: accept the `refresh_tokens` parameter in the request body rather than in the URL parameters. ([\#11430](https://github.com/matrix-org/synapse/issues/11430)) +- Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. ([\#11445](https://github.com/matrix-org/synapse/issues/11445)) +- Expose `synapse_homeserver` and `synapse_worker` commands as entry points to run Synapse's main process and worker processes, respectively. Contributed by @Ma27. ([\#11449](https://github.com/matrix-org/synapse/issues/11449)) +- `synctl stop` will now wait for Synapse to exit before returning. ([\#11459](https://github.com/matrix-org/synapse/issues/11459), [\#11490](https://github.com/matrix-org/synapse/issues/11490)) +- Extend the "delete room" admin API to work correctly on rooms which have previously been partially deleted. ([\#11523](https://github.com/matrix-org/synapse/issues/11523)) +- Add support for the `/_matrix/client/v3/login/sso/redirect/{idpId}` API from Matrix v1.1. This endpoint was overlooked when support for v3 endpoints was added in Synapse 1.48.0rc1. ([\#11451](https://github.com/matrix-org/synapse/issues/11451)) + + +Bugfixes +-------- + +- Fix using [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) batch sending in combination with event persistence workers. Contributed by @tulir at Beeper. ([\#11220](https://github.com/matrix-org/synapse/issues/11220)) +- Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection, properly this time. Also fix a race condition introduced in the previous insufficient fix in Synapse 1.47.0. ([\#11376](https://github.com/matrix-org/synapse/issues/11376)) +- The `/send_join` response now includes the stable `event` field instead of the unstable field from [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083). ([\#11413](https://github.com/matrix-org/synapse/issues/11413)) +- Fix a bug introduced in Synapse 1.47.0 where `send_join` could fail due to an outdated `ijson` version. ([\#11439](https://github.com/matrix-org/synapse/issues/11439), [\#11441](https://github.com/matrix-org/synapse/issues/11441), [\#11460](https://github.com/matrix-org/synapse/issues/11460)) +- Fix a bug introduced in Synapse 1.36.0 which could cause problems fetching event-signing keys from trusted key servers. ([\#11440](https://github.com/matrix-org/synapse/issues/11440)) +- Fix a bug introduced in Synapse 1.47.1 where the media repository would fail to work if the media store path contained any symbolic links. ([\#11446](https://github.com/matrix-org/synapse/issues/11446)) +- Fix an `LruCache` corruption bug, introduced in Synapse 1.38.0, that would cause certain requests to fail until the next Synapse restart. 
([\#11454](https://github.com/matrix-org/synapse/issues/11454)) +- Fix a long-standing bug where invites from ignored users were included in incremental syncs. ([\#11511](https://github.com/matrix-org/synapse/issues/11511)) +- Fix a regression in Synapse 1.48.0 where presence workers would not clear their presence updates over replication on shutdown. ([\#11518](https://github.com/matrix-org/synapse/issues/11518)) +- Fix a regression in Synapse 1.48.0 where the module API's `looping_background_call` method would spam errors to the logs when given a non-async function. ([\#11524](https://github.com/matrix-org/synapse/issues/11524)) + + +Updates to the Docker image +--------------------------- + +- Update `Dockerfile-workers` to healthcheck all workers in the container. ([\#11429](https://github.com/matrix-org/synapse/issues/11429)) + + +Improved Documentation +---------------------- + +- Update the media repository documentation. ([\#11415](https://github.com/matrix-org/synapse/issues/11415)) +- Update section about backward extremities in the room DAG concepts doc to correct the misconception about backward extremities indicating whether we have fetched an event's `prev_events`. ([\#11469](https://github.com/matrix-org/synapse/issues/11469)) + + +Internal Changes +---------------- + +- Add `Final` annotation to string constants in `synapse.api.constants` so that they get typed as `Literal`s. ([\#11356](https://github.com/matrix-org/synapse/issues/11356)) +- Add a check to ensure that users cannot start the Synapse master process when `worker_app` is set. ([\#11416](https://github.com/matrix-org/synapse/issues/11416)) +- Add a note about postgres memory management and hugepages to postgres doc. ([\#11467](https://github.com/matrix-org/synapse/issues/11467)) +- Add missing type hints to `synapse.config` module. ([\#11465](https://github.com/matrix-org/synapse/issues/11465)) +- Add missing type hints to `synapse.federation`. ([\#11483](https://github.com/matrix-org/synapse/issues/11483)) +- Add type annotations to `tests.storage.test_appservice`. ([\#11488](https://github.com/matrix-org/synapse/issues/11488), [\#11492](https://github.com/matrix-org/synapse/issues/11492)) +- Add type annotations to some of the configuration surrounding refresh tokens. ([\#11428](https://github.com/matrix-org/synapse/issues/11428)) +- Add type hints to `synapse/tests/rest/admin`. ([\#11501](https://github.com/matrix-org/synapse/issues/11501)) +- Add type hints to storage classes. ([\#11411](https://github.com/matrix-org/synapse/issues/11411)) +- Add wiki pages to documentation website. ([\#11402](https://github.com/matrix-org/synapse/issues/11402)) +- Clean up `tests.storage.test_main` to remove use of legacy code. ([\#11493](https://github.com/matrix-org/synapse/issues/11493)) +- Clean up `tests.test_visibility` to remove legacy code. ([\#11495](https://github.com/matrix-org/synapse/issues/11495)) +- Convert status codes to `HTTPStatus` in `synapse.rest.admin`. ([\#11452](https://github.com/matrix-org/synapse/issues/11452), [\#11455](https://github.com/matrix-org/synapse/issues/11455)) +- Extend the `scripts-dev/sign_json` script to support signing events. ([\#11486](https://github.com/matrix-org/synapse/issues/11486)) +- Improve internal types in push code. ([\#11409](https://github.com/matrix-org/synapse/issues/11409)) +- Improve type annotations in `synapse.module_api`. ([\#11029](https://github.com/matrix-org/synapse/issues/11029)) +- Improve type hints for `LruCache`. 
([\#11453](https://github.com/matrix-org/synapse/issues/11453)) +- Preparation for database schema simplifications: disambiguate queries on `state_key`. ([\#11497](https://github.com/matrix-org/synapse/issues/11497)) +- Refactor `backfilled` into specific behavior function arguments (`_persist_events_and_state_updates` and downstream calls). ([\#11417](https://github.com/matrix-org/synapse/issues/11417)) +- Refactor `get_version_string` to fix-up types and duplicated code. ([\#11468](https://github.com/matrix-org/synapse/issues/11468)) +- Refactor various parts of the `/sync` handler. ([\#11494](https://github.com/matrix-org/synapse/issues/11494), [\#11515](https://github.com/matrix-org/synapse/issues/11515)) +- Remove unnecessary `json.dumps` from `tests.rest.admin`. ([\#11461](https://github.com/matrix-org/synapse/issues/11461)) +- Save the OpenID Connect session ID on login. ([\#11482](https://github.com/matrix-org/synapse/issues/11482)) +- Update and clean up recently ported documentation pages. ([\#11466](https://github.com/matrix-org/synapse/issues/11466)) + + +Synapse 1.48.0 (2021-11-30) +=========================== + +This release removes support for the long-deprecated `trust_identity_server_for_password_resets` configuration flag. + +This release also fixes some performance issues with some background database updates introduced in Synapse 1.47.0. + +No significant changes since 1.48.0rc1. + +Synapse 1.48.0rc1 (2021-11-25) +============================== + +Features +-------- + +- Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). ([\#11161](https://github.com/matrix-org/synapse/issues/11161)) +- Support filtering by relation senders & types per [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). ([\#11236](https://github.com/matrix-org/synapse/issues/11236)) +- Add support for the `/_matrix/client/v3` and `/_matrix/media/v3` APIs from Matrix v1.1. ([\#11318](https://github.com/matrix-org/synapse/issues/11318), [\#11371](https://github.com/matrix-org/synapse/issues/11371)) +- Support the stable version of [MSC2778](https://github.com/matrix-org/matrix-doc/pull/2778): the `m.login.application_service` login type. Contributed by @tulir. ([\#11335](https://github.com/matrix-org/synapse/issues/11335)) +- Add a new version of delete room admin API `DELETE /_synapse/admin/v2/rooms/<room_id>` to run it in the background. Contributed by @dklimpel. ([\#11223](https://github.com/matrix-org/synapse/issues/11223)) +- Allow the admin [Delete Room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api) to block a room without the need to join it. ([\#11228](https://github.com/matrix-org/synapse/issues/11228)) +- Add an admin API to un-shadow-ban a user. ([\#11347](https://github.com/matrix-org/synapse/issues/11347)) +- Add an admin API to run background database schema updates. ([\#11352](https://github.com/matrix-org/synapse/issues/11352)) +- Add an admin API for blocking a room. ([\#11324](https://github.com/matrix-org/synapse/issues/11324)) +- Update the JWT login type to support a custom `sub` claim. ([\#11361](https://github.com/matrix-org/synapse/issues/11361)) +- Store and allow querying of arbitrary event relations. ([\#11391](https://github.com/matrix-org/synapse/issues/11391)) + + +Bugfixes +-------- + +- Fix a long-standing bug wherein display names or avatar URLs containing null bytes cause an internal server error when stored in the DB. 
([\#11230](https://github.com/matrix-org/synapse/issues/11230)) +- Prevent [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical state events from being pushed to an application service via `/transactions`. ([\#11265](https://github.com/matrix-org/synapse/issues/11265)) +- Fix a long-standing bug where uploading extremely thin images (e.g. 1000x1) would fail. Contributed by @Neeeflix. ([\#11288](https://github.com/matrix-org/synapse/issues/11288)) +- Fix a bug, introduced in Synapse 1.46.0, which caused the `check_3pid_auth` and `on_logged_out` callbacks in legacy password authentication provider modules to not be registered. Modules using the generic module interface were not affected. ([\#11340](https://github.com/matrix-org/synapse/issues/11340)) +- Fix a bug introduced in 1.41.0 where space hierarchy responses would be incorrectly reused if multiple users were to make the same request at the same time. ([\#11355](https://github.com/matrix-org/synapse/issues/11355)) +- Fix a bug introduced in 1.45.0 where the `read_templates` method of the module API would error. ([\#11377](https://github.com/matrix-org/synapse/issues/11377)) +- Fix an issue introduced in 1.47.0 which prevented servers re-joining rooms they had previously left, if their signing keys were replaced. ([\#11379](https://github.com/matrix-org/synapse/issues/11379)) +- Fix a bug introduced in 1.13.0 where creating and publishing a room could cause errors if `room_list_publication_rules` is configured. ([\#11392](https://github.com/matrix-org/synapse/issues/11392)) +- Improve performance of various background database updates. ([\#11421](https://github.com/matrix-org/synapse/issues/11421), [\#11422](https://github.com/matrix-org/synapse/issues/11422)) + + +Improved Documentation +---------------------- + +- Suggest users of the Debian packages add configuration to `/etc/matrix-synapse/conf.d/` to prevent, upon upgrade, being asked to choose between their configuration and the maintainer's. ([\#11281](https://github.com/matrix-org/synapse/issues/11281)) +- Fix typos in the documentation for the `username_available` admin API. Contributed by Stanislav Motylkov. ([\#11286](https://github.com/matrix-org/synapse/issues/11286)) +- Add Single Sign-On, SAML and CAS pages to the documentation. ([\#11298](https://github.com/matrix-org/synapse/issues/11298)) +- Change the word 'Home server' to the single word 'homeserver' in documentation. ([\#11320](https://github.com/matrix-org/synapse/issues/11320)) +- Fix missing quotes for wildcard domains in `federation_certificate_verification_whitelist`. ([\#11381](https://github.com/matrix-org/synapse/issues/11381)) + + +Deprecations and Removals +------------------------- + +- Remove deprecated `trust_identity_server_for_password_resets` configuration flag. ([\#11333](https://github.com/matrix-org/synapse/issues/11333), [\#11395](https://github.com/matrix-org/synapse/issues/11395)) + + +Internal Changes +---------------- + +- Add type annotations to `synapse.metrics`. ([\#10847](https://github.com/matrix-org/synapse/issues/10847)) +- Split out federated PDU retrieval function into a non-cached version. ([\#11242](https://github.com/matrix-org/synapse/issues/11242)) +- Clean up code relating to to-device messages and sending ephemeral events to application services. ([\#11247](https://github.com/matrix-org/synapse/issues/11247)) +- Fix a small typo in the error response when a relation type other than 'm.annotation' is passed to `GET /rooms/{room_id}/aggregations/{event_id}`. 
([\#11278](https://github.com/matrix-org/synapse/issues/11278)) +- Drop unused database tables `room_stats_historical` and `user_stats_historical`. ([\#11280](https://github.com/matrix-org/synapse/issues/11280)) +- Require all files in synapse/ and tests/ to pass mypy unless specifically excluded. ([\#11282](https://github.com/matrix-org/synapse/issues/11282), [\#11285](https://github.com/matrix-org/synapse/issues/11285), [\#11359](https://github.com/matrix-org/synapse/issues/11359)) +- Add missing type hints to `synapse.app`. ([\#11287](https://github.com/matrix-org/synapse/issues/11287)) +- Remove unused parameters on `FederationEventHandler._check_event_auth`. ([\#11292](https://github.com/matrix-org/synapse/issues/11292)) +- Add type hints to `synapse._scripts`. ([\#11297](https://github.com/matrix-org/synapse/issues/11297)) +- Fix an issue which prevented the `remove_deleted_devices_from_device_inbox` background database schema update from running when updating from a recent Synapse version. ([\#11303](https://github.com/matrix-org/synapse/issues/11303)) +- Add type hints to storage classes. ([\#11307](https://github.com/matrix-org/synapse/issues/11307), [\#11310](https://github.com/matrix-org/synapse/issues/11310), [\#11311](https://github.com/matrix-org/synapse/issues/11311), [\#11312](https://github.com/matrix-org/synapse/issues/11312), [\#11313](https://github.com/matrix-org/synapse/issues/11313), [\#11314](https://github.com/matrix-org/synapse/issues/11314), [\#11316](https://github.com/matrix-org/synapse/issues/11316), [\#11322](https://github.com/matrix-org/synapse/issues/11322), [\#11332](https://github.com/matrix-org/synapse/issues/11332), [\#11339](https://github.com/matrix-org/synapse/issues/11339), [\#11342](https://github.com/matrix-org/synapse/issues/11342)) +- Add type hints to `synapse.util`. ([\#11321](https://github.com/matrix-org/synapse/issues/11321), [\#11328](https://github.com/matrix-org/synapse/issues/11328)) +- Improve type annotations in Synapse's test suite. ([\#11323](https://github.com/matrix-org/synapse/issues/11323), [\#11330](https://github.com/matrix-org/synapse/issues/11330)) +- Test that room alias deletion works as intended. ([\#11327](https://github.com/matrix-org/synapse/issues/11327)) +- Add type annotations for some methods and properties in the module API. ([\#11341](https://github.com/matrix-org/synapse/issues/11341)) +- Fix running `scripts-dev/complement.sh`, which was broken in v1.47.0rc1. ([\#11368](https://github.com/matrix-org/synapse/issues/11368)) +- Rename internal functions for token generation to better reflect what they do. ([\#11369](https://github.com/matrix-org/synapse/issues/11369), [\#11370](https://github.com/matrix-org/synapse/issues/11370)) +- Add type hints to configuration classes. ([\#11377](https://github.com/matrix-org/synapse/issues/11377)) +- Publish a `develop` image to Docker Hub. ([\#11380](https://github.com/matrix-org/synapse/issues/11380)) +- Keep fallback key marked as used if it's re-uploaded. ([\#11382](https://github.com/matrix-org/synapse/issues/11382)) +- Use `auto_attribs` on the `attrs` class `RefreshTokenLookupResult`. ([\#11386](https://github.com/matrix-org/synapse/issues/11386)) +- Rename unstable `access_token_lifetime` configuration option to `refreshable_access_token_lifetime` to make it clear it only concerns refreshable access tokens. ([\#11388](https://github.com/matrix-org/synapse/issues/11388)) +- Do not run the broken MSC2716 tests when running `scripts-dev/complement.sh`. 
([\#11389](https://github.com/matrix-org/synapse/issues/11389)) +- Remove dead code from supporting ACME. ([\#11393](https://github.com/matrix-org/synapse/issues/11393)) +- Refactor including the bundled relations when serializing an event. ([\#11408](https://github.com/matrix-org/synapse/issues/11408)) + + +Synapse 1.47.1 (2021-11-23) +=========================== + +This release fixes a security issue in the media store, affecting all prior releases of Synapse. Server administrators are encouraged to update Synapse as soon as possible. We are not aware of these vulnerabilities being exploited in the wild. + +Server administrators who are unable to update Synapse may use the workarounds described in the linked GitHub Security Advisory below. + +Security advisory +----------------- + +The following issue is fixed in 1.47.1. + +- **[GHSA-3hfw-x7gx-437c](https://github.com/matrix-org/synapse/security/advisories/GHSA-3hfw-x7gx-437c) / [CVE-2021-41281](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-41281): Path traversal when downloading remote media.** + + Synapse instances with the media repository enabled can be tricked into downloading a file from a remote server into an arbitrary directory, potentially outside the media store directory. + + The last two directories and file name of the path are chosen randomly by Synapse and cannot be controlled by an attacker, which limits the impact. + + Homeservers with the media repository disabled are unaffected. Homeservers configured with a federation whitelist are also unaffected. + + Fixed by [91f2bd090](https://github.com/matrix-org/synapse/commit/91f2bd090). + + +Synapse 1.47.0 (2021-11-17) +=========================== + +No significant changes since 1.47.0rc3. + + +Synapse 1.47.0rc3 (2021-11-16) +============================== + +Bugfixes +-------- + +- Fix a bug introduced in 1.47.0rc1 which caused worker processes to not halt startup in the presence of outstanding database migrations. ([\#11346](https://github.com/matrix-org/synapse/issues/11346)) +- Fix a bug introduced in 1.47.0rc1 which prevented the 'remove deleted devices from `device_inbox` column' background process from running when updating from a recent Synapse version. ([\#11303](https://github.com/matrix-org/synapse/issues/11303), [\#11353](https://github.com/matrix-org/synapse/issues/11353)) + + +Synapse 1.47.0rc2 (2021-11-10) +============================== + +This fixes an issue with publishing the Debian packages for 1.47.0rc1. +It is otherwise identical to 1.47.0rc1. + + +Synapse 1.47.0rc1 (2021-11-09) +============================== + +Deprecations and Removals +------------------------- + +- The `user_may_create_room_with_invites` module callback is now deprecated. Please refer to the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#upgrading-to-v1470) for more information. ([\#11206](https://github.com/matrix-org/synapse/issues/11206)) +- Remove deprecated admin API to delete rooms (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). ([\#11213](https://github.com/matrix-org/synapse/issues/11213)) + + +Features +-------- + +- Advertise support for Client-Server API r0.6.1. ([\#11097](https://github.com/matrix-org/synapse/issues/11097)) +- Add search by room ID and room alias to the List Room admin API. ([\#11099](https://github.com/matrix-org/synapse/issues/11099)) +- Add an `on_new_event` third-party rules callback to allow Synapse modules to act after an event has been sent into a room. 
([\#11126](https://github.com/matrix-org/synapse/issues/11126)) +- Add a module API method to update a user's membership in a room. ([\#11147](https://github.com/matrix-org/synapse/issues/11147)) +- Add metrics for thread pool usage. ([\#11178](https://github.com/matrix-org/synapse/issues/11178)) +- Support the stable room type field for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288). ([\#11187](https://github.com/matrix-org/synapse/issues/11187)) +- Add a module API method to retrieve the current state of a room. ([\#11204](https://github.com/matrix-org/synapse/issues/11204)) +- Calculate a default value for `public_baseurl` based on `server_name`. ([\#11210](https://github.com/matrix-org/synapse/issues/11210)) +- Add support for serving `/.well-known/matrix/server` files, to redirect federation traffic to port 443. ([\#11211](https://github.com/matrix-org/synapse/issues/11211)) +- Add admin APIs to pause, start and check the status of background updates. ([\#11263](https://github.com/matrix-org/synapse/issues/11263)) + + +Bugfixes +-------- + +- Fix a long-standing bug which allowed hidden devices to receive to-device messages, resulting in unnecessary database bloat. ([\#10097](https://github.com/matrix-org/synapse/issues/10097)) +- Fix a long-standing bug where messages in the `device_inbox` table for deleted devices would persist indefinitely. Contributed by @dklimpel and @JohannesKleine. ([\#10969](https://github.com/matrix-org/synapse/issues/10969), [\#11212](https://github.com/matrix-org/synapse/issues/11212)) +- Do not accept events if a third-party rule `check_event_allowed` callback raises an exception. ([\#11033](https://github.com/matrix-org/synapse/issues/11033)) +- Fix long-standing bug where verification requests could fail in certain cases if a federation whitelist was in place but did not include your own homeserver. ([\#11129](https://github.com/matrix-org/synapse/issues/11129)) +- Allow an empty list of `state_events_at_start` to be sent when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint and the author of the historical messages is already part of the current room state at the given `?prev_event_id`. ([\#11188](https://github.com/matrix-org/synapse/issues/11188)) +- Fix a bug introduced in Synapse 1.45.0 which prevented the `synapse_review_recent_signups` script from running. Contributed by @samuel-p. ([\#11191](https://github.com/matrix-org/synapse/issues/11191)) +- Delete `to_device` messages for hidden devices that will never be read, reducing database size. ([\#11199](https://github.com/matrix-org/synapse/issues/11199)) +- Fix a long-standing bug wherein a missing `Content-Type` header when downloading remote media would cause Synapse to throw an error. ([\#11200](https://github.com/matrix-org/synapse/issues/11200)) +- Fix a long-standing bug which could result in serialization errors and potentially duplicate transaction data when sending ephemeral events to application services. Contributed by @Fizzadar at Beeper. ([\#11207](https://github.com/matrix-org/synapse/issues/11207)) +- Fix a bug introduced in Synapse 1.35.0 which made it impossible to join rooms that return a `send_join` response containing floats. ([\#11217](https://github.com/matrix-org/synapse/issues/11217)) +- Fix long-standing bug where cross signing keys were not included in the response to `/r0/keys/query` the first time a remote user was queried. 
([\#11234](https://github.com/matrix-org/synapse/issues/11234)) +- Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection. ([\#11240](https://github.com/matrix-org/synapse/issues/11240)) +- Fix a bug preventing Synapse from being rolled back to an earlier version when using workers. ([\#11255](https://github.com/matrix-org/synapse/issues/11255), [\#11276](https://github.com/matrix-org/synapse/issues/11276)) +- Fix a bug introduced in Synapse 1.37.1 which caused a remote event being processed by a worker to not get processed on restart if the worker was killed. ([\#11262](https://github.com/matrix-org/synapse/issues/11262)) +- Only allow old Element/Riot Android clients to send read receipts without a request body. All other clients must include a request body as required by the specification. Contributed by @rogersheu. ([\#11157](https://github.com/matrix-org/synapse/issues/11157)) + + +Updates to the Docker image +--------------------------- + +- Avoid changing user ID when started as a non-root user, and no explicit `UID` is set. ([\#11209](https://github.com/matrix-org/synapse/issues/11209)) + + +Improved Documentation +---------------------- + +- Improve example HAProxy config in the docs to properly handle HTTP `Host` headers with port information. This is required for federation over port 443 to work correctly. ([\#11128](https://github.com/matrix-org/synapse/issues/11128)) +- Add documentation for using Authentik as an OpenID Connect Identity Provider. Contributed by @samip5. ([\#11151](https://github.com/matrix-org/synapse/issues/11151)) +- Clarify lack of support for Windows. ([\#11198](https://github.com/matrix-org/synapse/issues/11198)) +- Improve code formatting and fix a few typos in docs. Contributed by @sumnerevans at Beeper. ([\#11221](https://github.com/matrix-org/synapse/issues/11221)) +- Add documentation for using LemonLDAP as an OpenID Connect Identity Provider. Contributed by @l00ptr. ([\#11257](https://github.com/matrix-org/synapse/issues/11257)) + + +Internal Changes +---------------- + +- Add type annotations for the `log_function` decorator. ([\#10943](https://github.com/matrix-org/synapse/issues/10943)) +- Add type hints to `synapse.events`. ([\#11098](https://github.com/matrix-org/synapse/issues/11098)) +- Remove and document unnecessary `RoomStreamToken` checks in application service ephemeral event code. ([\#11137](https://github.com/matrix-org/synapse/issues/11137)) +- Add type hints so that `synapse.http` passes `mypy` checks. ([\#11164](https://github.com/matrix-org/synapse/issues/11164)) +- Update scripts to pass Shellcheck lints. ([\#11166](https://github.com/matrix-org/synapse/issues/11166)) +- Add knock information in admin export. Contributed by Rafael Gonçalves. ([\#11171](https://github.com/matrix-org/synapse/issues/11171)) +- Add tests to check that `ClientIpStore.get_last_client_ip_by_device` and `get_user_ip_and_agents` combine database and in-memory data correctly. ([\#11179](https://github.com/matrix-org/synapse/issues/11179)) +- Refactor `Filter` to check different fields depending on the data type. ([\#11194](https://github.com/matrix-org/synapse/issues/11194)) +- Improve type hints for the relations datastore. ([\#11205](https://github.com/matrix-org/synapse/issues/11205)) +- Replace outdated links in the pull request checklist with links to the rendered documentation. 
([\#11225](https://github.com/matrix-org/synapse/issues/11225)) +- Fix a bug in unit test `test_block_room_and_not_purge`. ([\#11226](https://github.com/matrix-org/synapse/issues/11226)) +- In `ObservableDeferred`, run observers in the order they were registered. ([\#11229](https://github.com/matrix-org/synapse/issues/11229)) +- Minor speed up to start up times and getting updates for groups by adding missing index to `local_group_updates.stream_id`. ([\#11231](https://github.com/matrix-org/synapse/issues/11231)) +- Add `twine` and `towncrier` as dev dependencies, as they're used by the release script. ([\#11233](https://github.com/matrix-org/synapse/issues/11233)) +- Allow `stream_writers.typing` config to be a list of one worker. ([\#11237](https://github.com/matrix-org/synapse/issues/11237)) +- Remove debugging statement in tests. ([\#11239](https://github.com/matrix-org/synapse/issues/11239)) +- Fix [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical messages backfilling in random order on remote homeservers. ([\#11244](https://github.com/matrix-org/synapse/issues/11244)) +- Add an additional test for the `cachedList` method decorator. ([\#11246](https://github.com/matrix-org/synapse/issues/11246)) +- Make minor correction to the type of `auth_checkers` callbacks. ([\#11253](https://github.com/matrix-org/synapse/issues/11253)) +- Clean up trivial aspects of the Debian package build tooling. ([\#11269](https://github.com/matrix-org/synapse/issues/11269), [\#11273](https://github.com/matrix-org/synapse/issues/11273)) +- Blacklist new SyTest that checks that key uploads are valid pending the validation being implemented in Synapse. ([\#11270](https://github.com/matrix-org/synapse/issues/11270)) + + Synapse 1.46.0 (2021-11-02) =========================== @@ -8584,14 +8898,14 @@ General: Federation: -- Add key distribution mechanisms for fetching public keys of unavailable remote home servers. See [Retrieving Server Keys](https://github.com/matrix-org/matrix-doc/blob/6f2698/specification/30_server_server_api.rst#retrieving-server-keys) in the spec. +- Add key distribution mechanisms for fetching public keys of unavailable remote homeservers. See [Retrieving Server Keys](https://github.com/matrix-org/matrix-doc/blob/6f2698/specification/30_server_server_api.rst#retrieving-server-keys) in the spec. Configuration: - Add support for multiple config files. - Add support for dictionaries in config files. - Remove support for specifying config options on the command line, except for: - - `--daemonize` - Daemonize the home server. + - `--daemonize` - Daemonize the homeserver. - `--manhole` - Turn on the twisted telnet manhole service on the given port. - `--database-path` - The path to a sqlite database to use. - `--verbose` - The verbosity level. @@ -8796,7 +9110,7 @@ This version adds support for using a TURN server. See docs/turn-howto.rst on ho Homeserver: - Add support for redaction of messages. -- Fix bug where inviting a user on a remote home server could take up to 20-30s. +- Fix bug where inviting a user on a remote homeserver could take up to 20-30s. - Implement a get current room state API. - Add support specifying and retrieving turn server configuration. @@ -8886,7 +9200,7 @@ Changes in synapse 0.2.3 (2014-09-12) Homeserver: -- Fix bug where we stopped sending events to remote home servers if a user from that home server left, even if there were some still in the room. 
+- Fix bug where we stopped sending events to remote homeservers if a user from that homeserver left, even if there were some still in the room. - Fix bugs in the state conflict resolution where it was incorrectly rejecting events. Webclient: diff --git a/changelog.d/10097.bugfix b/changelog.d/10097.bugfix deleted file mode 100644 index 5d3d9587c2..0000000000 --- a/changelog.d/10097.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug which allowed hidden devices to receive to-device messages, resulting in unnecessary database bloat. diff --git a/changelog.d/10520.misc b/changelog.d/10520.misc new file mode 100644 index 0000000000..a911e165da --- /dev/null +++ b/changelog.d/10520.misc @@ -0,0 +1 @@ +Send and handle cross-signing messages using the stable prefix. diff --git a/changelog.d/10943.misc b/changelog.d/10943.misc deleted file mode 100644 index 3ce28d1a67..0000000000 --- a/changelog.d/10943.misc +++ /dev/null @@ -1 +0,0 @@ -Add type annotations for the `log_function` decorator. diff --git a/changelog.d/10969.bugfix b/changelog.d/10969.bugfix deleted file mode 100644 index 89c299b8e8..0000000000 --- a/changelog.d/10969.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where messages in the `device_inbox` table for deleted devices would persist indefinitely. Contributed by @dklimpel and @JohannesKleine. diff --git a/changelog.d/11033.bugfix b/changelog.d/11033.bugfix deleted file mode 100644 index fa99f187b8..0000000000 --- a/changelog.d/11033.bugfix +++ /dev/null @@ -1 +0,0 @@ -Do not accept events if a third-party rule module API callback raises an exception. diff --git a/changelog.d/11097.feature b/changelog.d/11097.feature deleted file mode 100644 index d7563a406c..0000000000 --- a/changelog.d/11097.feature +++ /dev/null @@ -1 +0,0 @@ -Advertise support for Client-Server API r0.6.1. \ No newline at end of file diff --git a/changelog.d/11098.misc b/changelog.d/11098.misc deleted file mode 100644 index 1e337bee54..0000000000 --- a/changelog.d/11098.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `synapse.events`. diff --git a/changelog.d/11099.feature b/changelog.d/11099.feature deleted file mode 100644 index c9126d4a9d..0000000000 --- a/changelog.d/11099.feature +++ /dev/null @@ -1 +0,0 @@ -Add search by room ID and room alias to List Room admin API. \ No newline at end of file diff --git a/changelog.d/11126.feature b/changelog.d/11126.feature deleted file mode 100644 index c6078fe081..0000000000 --- a/changelog.d/11126.feature +++ /dev/null @@ -1 +0,0 @@ -Add an `on_new_event` third-party rules callback to allow Synapse modules to act after an event has been sent into a room. diff --git a/changelog.d/11128.doc b/changelog.d/11128.doc deleted file mode 100644 index c024679218..0000000000 --- a/changelog.d/11128.doc +++ /dev/null @@ -1 +0,0 @@ -Improve example HAProxy config in the docs to properly handle host headers with port information. This is required for federation over port 443 to work correctly. diff --git a/changelog.d/11129.bugfix b/changelog.d/11129.bugfix deleted file mode 100644 index 5e9aa538ec..0000000000 --- a/changelog.d/11129.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix long-standing bug where verification requests could fail in certain cases if whitelist was in place but did not include your own homeserver. 
\ No newline at end of file diff --git a/changelog.d/11137.misc b/changelog.d/11137.misc deleted file mode 100644 index f0d6476f48..0000000000 --- a/changelog.d/11137.misc +++ /dev/null @@ -1 +0,0 @@ -Remove and document unnecessary `RoomStreamToken` checks in application service ephemeral event code. \ No newline at end of file diff --git a/changelog.d/11147.feature b/changelog.d/11147.feature deleted file mode 100644 index af72d85c20..0000000000 --- a/changelog.d/11147.feature +++ /dev/null @@ -1 +0,0 @@ -Add a module API method to update a user's membership in a room. diff --git a/changelog.d/11151.doc b/changelog.d/11151.doc deleted file mode 100644 index 68cd99471f..0000000000 --- a/changelog.d/11151.doc +++ /dev/null @@ -1 +0,0 @@ -Add documentation for using Authentik as an OpenID Connect Identity Provider. Contributed by @samip5. \ No newline at end of file diff --git a/changelog.d/11157.misc b/changelog.d/11157.misc deleted file mode 100644 index 75444c51d1..0000000000 --- a/changelog.d/11157.misc +++ /dev/null @@ -1 +0,0 @@ -Only allow old Element/Riot Android clients to send read receipts without a request body. All other clients must include a request body as required by the specification. Contributed by @rogersheu. diff --git a/changelog.d/11164.misc b/changelog.d/11164.misc deleted file mode 100644 index 751da49183..0000000000 --- a/changelog.d/11164.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints so that `synapse.http` passes `mypy` checks. \ No newline at end of file diff --git a/changelog.d/11166.misc b/changelog.d/11166.misc deleted file mode 100644 index 79342e43d9..0000000000 --- a/changelog.d/11166.misc +++ /dev/null @@ -1 +0,0 @@ -Update scripts to pass Shellcheck lints. diff --git a/changelog.d/11171.misc b/changelog.d/11171.misc deleted file mode 100644 index b6a41a96da..0000000000 --- a/changelog.d/11171.misc +++ /dev/null @@ -1 +0,0 @@ -Add knock information in admin export. Contributed by Rafael Gonçalves. diff --git a/changelog.d/11178.feature b/changelog.d/11178.feature deleted file mode 100644 index 10b1cdffdc..0000000000 --- a/changelog.d/11178.feature +++ /dev/null @@ -1 +0,0 @@ -Add metrics for thread pool usage. diff --git a/changelog.d/11179.misc b/changelog.d/11179.misc deleted file mode 100644 index aded2e8367..0000000000 --- a/changelog.d/11179.misc +++ /dev/null @@ -1 +0,0 @@ -Add tests to check that `ClientIpStore.get_last_client_ip_by_device` and `get_user_ip_and_agents` combine database and in-memory data correctly. diff --git a/changelog.d/11187.feature b/changelog.d/11187.feature deleted file mode 100644 index dd28109030..0000000000 --- a/changelog.d/11187.feature +++ /dev/null @@ -1 +0,0 @@ -Support the stable room type field for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288). diff --git a/changelog.d/11188.bugfix b/changelog.d/11188.bugfix deleted file mode 100644 index 0688743c00..0000000000 --- a/changelog.d/11188.bugfix +++ /dev/null @@ -1 +0,0 @@ -Allow an empty list of `state_events_at_start` to be sent when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint and the author of the historical messages is already part of the current room state at the given `?prev_event_id`. diff --git a/changelog.d/11191.bugfix b/changelog.d/11191.bugfix deleted file mode 100644 index 9104db7f0e..0000000000 --- a/changelog.d/11191.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse 1.45.0 which prevented the `synapse_review_recent_signups` script from running. 
Contributed by @samuel-p. diff --git a/changelog.d/11194.misc b/changelog.d/11194.misc deleted file mode 100644 index fc1d06ba89..0000000000 --- a/changelog.d/11194.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor `Filter` to check different fields depending on the data type. diff --git a/changelog.d/11198.doc b/changelog.d/11198.doc deleted file mode 100644 index 54ec94acbc..0000000000 --- a/changelog.d/11198.doc +++ /dev/null @@ -1 +0,0 @@ -Clarify lack of support for Windows. diff --git a/changelog.d/11199.bugfix b/changelog.d/11199.bugfix deleted file mode 100644 index dc3ea8d515..0000000000 --- a/changelog.d/11199.bugfix +++ /dev/null @@ -1 +0,0 @@ -Delete `to_device` messages for hidden devices that will never be read, reducing database size. \ No newline at end of file diff --git a/changelog.d/11200.bugfix b/changelog.d/11200.bugfix deleted file mode 100644 index c855081986..0000000000 --- a/changelog.d/11200.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug wherein a missing `Content-Type` header when downloading remote media would cause Synapse to throw an error. \ No newline at end of file diff --git a/changelog.d/11204.feature b/changelog.d/11204.feature deleted file mode 100644 index f58ed4b3dc..0000000000 --- a/changelog.d/11204.feature +++ /dev/null @@ -1 +0,0 @@ -Add a module API method to retrieve the current state of a room. diff --git a/changelog.d/11205.misc b/changelog.d/11205.misc deleted file mode 100644 index 62395c9432..0000000000 --- a/changelog.d/11205.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints for the relations datastore. diff --git a/changelog.d/11206.removal b/changelog.d/11206.removal deleted file mode 100644 index cf05b16672..0000000000 --- a/changelog.d/11206.removal +++ /dev/null @@ -1 +0,0 @@ -The `user_may_create_room_with_invites` module callback is now deprecated. Please refer to the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#upgrading-to-v1470) for more information. diff --git a/changelog.d/11207.bugfix b/changelog.d/11207.bugfix deleted file mode 100644 index 7e98d565a1..0000000000 --- a/changelog.d/11207.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug which could result in serialization errors and potentially duplicate transaction data when sending ephemeral events to application services. Contributed by @Fizzadar at Beeper. diff --git a/changelog.d/11209.docker b/changelog.d/11209.docker deleted file mode 100644 index 838b165ac9..0000000000 --- a/changelog.d/11209.docker +++ /dev/null @@ -1 +0,0 @@ -Avoid changing userid when started as a non-root user, and no explicit `UID` is set. diff --git a/changelog.d/11210.feature b/changelog.d/11210.feature deleted file mode 100644 index 8f8e386415..0000000000 --- a/changelog.d/11210.feature +++ /dev/null @@ -1 +0,0 @@ -Calculate a default value for `public_baseurl` based on `server_name`. diff --git a/changelog.d/11211.feature b/changelog.d/11211.feature deleted file mode 100644 index feeb0cf089..0000000000 --- a/changelog.d/11211.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for serving `/.well-known/matrix/server` files, to redirect federation traffic to port 443. diff --git a/changelog.d/11212.bugfix b/changelog.d/11212.bugfix deleted file mode 100644 index ba6efab25b..0000000000 --- a/changelog.d/11212.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where messages in the `device_inbox` table for deleted devices would persist indefinitely. Contributed by @dklimpel and @JohannesKleine. 
\ No newline at end of file diff --git a/changelog.d/11213.removal b/changelog.d/11213.removal deleted file mode 100644 index 9e5ec936e3..0000000000 --- a/changelog.d/11213.removal +++ /dev/null @@ -1 +0,0 @@ -Remove deprecated admin API to delete rooms (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). \ No newline at end of file diff --git a/changelog.d/11217.bugfix b/changelog.d/11217.bugfix deleted file mode 100644 index 67ebb0d0e3..0000000000 --- a/changelog.d/11217.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in 1.35.0 which made it impossible to join rooms that return a `send_join` response containing floats. \ No newline at end of file diff --git a/changelog.d/11221.doc b/changelog.d/11221.doc deleted file mode 100644 index 17010bac8b..0000000000 --- a/changelog.d/11221.doc +++ /dev/null @@ -1 +0,0 @@ -Improve code formatting and fix a few typos in docs. Contributed by @sumnerevans at Beeper. diff --git a/changelog.d/11225.misc b/changelog.d/11225.misc deleted file mode 100644 index f14f65f9d4..0000000000 --- a/changelog.d/11225.misc +++ /dev/null @@ -1 +0,0 @@ -Replace outdated links in the pull request checklist with links to the rendered documentation. diff --git a/changelog.d/11226.misc b/changelog.d/11226.misc deleted file mode 100644 index 9ed4760ae0..0000000000 --- a/changelog.d/11226.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a bug in unit test `test_block_room_and_not_purge`. diff --git a/changelog.d/11228.feature b/changelog.d/11228.feature deleted file mode 100644 index 33c1756b50..0000000000 --- a/changelog.d/11228.feature +++ /dev/null @@ -1 +0,0 @@ -Allow the admin [Delete Room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api) to block a room without the need to join it. diff --git a/changelog.d/11229.misc b/changelog.d/11229.misc deleted file mode 100644 index 7bb01cf079..0000000000 --- a/changelog.d/11229.misc +++ /dev/null @@ -1 +0,0 @@ -`ObservableDeferred`: run registered observers in order. diff --git a/changelog.d/11231.misc b/changelog.d/11231.misc deleted file mode 100644 index c7fca7071e..0000000000 --- a/changelog.d/11231.misc +++ /dev/null @@ -1 +0,0 @@ -Minor speed up to start up times and getting updates for groups by adding missing index to `local_group_updates.stream_id`. diff --git a/changelog.d/11233.misc b/changelog.d/11233.misc deleted file mode 100644 index fdf9e5642e..0000000000 --- a/changelog.d/11233.misc +++ /dev/null @@ -1 +0,0 @@ -Add `twine` and `towncrier` as dev dependencies, as they're used by the release script. diff --git a/changelog.d/11234.bugfix b/changelog.d/11234.bugfix deleted file mode 100644 index c0c02a58f6..0000000000 --- a/changelog.d/11234.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix long-standing bug where cross signing keys were not included in the response to `/r0/keys/query` the first time a remote user was queried. diff --git a/changelog.d/11236.feature b/changelog.d/11236.feature deleted file mode 100644 index e7aeee2aa6..0000000000 --- a/changelog.d/11236.feature +++ /dev/null @@ -1 +0,0 @@ -Support filtering by relation senders & types per [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/changelog.d/11237.misc b/changelog.d/11237.misc deleted file mode 100644 index b90efc6535..0000000000 --- a/changelog.d/11237.misc +++ /dev/null @@ -1 +0,0 @@ -Allow `stream_writers.typing` config to be a list of one worker. 
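As an aside on the `/r0/keys/query` fix noted above ([\#11234]): a device/cross-signing key query looks roughly like the sketch below. The server name, user ID and token are placeholders; this is a hedged illustration of the Client-Server API endpoint, not output from this commit.

```sh
# Query a remote user's device and cross-signing keys (placeholders throughout):
curl -s -X POST 'https://homeserver.example/_matrix/client/r0/keys/query' \
  -H "Authorization: Bearer $ACCESS_TOKEN" \
  -d '{"device_keys": {"@alice:remote.example": []}}'
# Before the fix, the `master_keys`/`self_signing_keys` sections could be
# missing from the very first response for a not-yet-seen remote user.
```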
diff --git a/changelog.d/11239.misc b/changelog.d/11239.misc deleted file mode 100644 index 48a796bed0..0000000000 --- a/changelog.d/11239.misc +++ /dev/null @@ -1 +0,0 @@ -Remove debugging statement in tests. diff --git a/changelog.d/11240.bugfix b/changelog.d/11240.bugfix deleted file mode 100644 index 94d73f67e3..0000000000 --- a/changelog.d/11240.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection. diff --git a/changelog.d/11242.misc b/changelog.d/11242.misc deleted file mode 100644 index 3a98259edf..0000000000 --- a/changelog.d/11242.misc +++ /dev/null @@ -1 +0,0 @@ -Split out federated PDU retrieval function into a non-cached version. diff --git a/changelog.d/11244.misc b/changelog.d/11244.misc deleted file mode 100644 index c6e65df97f..0000000000 --- a/changelog.d/11244.misc +++ /dev/null @@ -1 +0,0 @@ -Fix [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical messages backfilling in random order on remote homeservers. diff --git a/changelog.d/11246.misc b/changelog.d/11246.misc deleted file mode 100644 index e5e912c1b0..0000000000 --- a/changelog.d/11246.misc +++ /dev/null @@ -1 +0,0 @@ -Add an additional test for the `cachedList` method decorator. diff --git a/changelog.d/11247.misc b/changelog.d/11247.misc deleted file mode 100644 index 5ce701560e..0000000000 --- a/changelog.d/11247.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up code relating to to-device messages and sending ephemeral events to application services. \ No newline at end of file diff --git a/changelog.d/11253.misc b/changelog.d/11253.misc deleted file mode 100644 index 71c55a2751..0000000000 --- a/changelog.d/11253.misc +++ /dev/null @@ -1 +0,0 @@ -Make minor correction to the type of `auth_checkers` callbacks. diff --git a/changelog.d/11255.bugfix b/changelog.d/11255.bugfix deleted file mode 100644 index ce72592624..0000000000 --- a/changelog.d/11255.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix rolling back Synapse version when using workers. diff --git a/changelog.d/11257.doc b/changelog.d/11257.doc deleted file mode 100644 index 1205be2add..0000000000 --- a/changelog.d/11257.doc +++ /dev/null @@ -1 +0,0 @@ -Add documentation for using LemonLDAP as an OpenID Connect Identity Provider. Contributed by @l00ptr. diff --git a/changelog.d/11262.bugfix b/changelog.d/11262.bugfix deleted file mode 100644 index 768fbb8973..0000000000 --- a/changelog.d/11262.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug where if a remote event is being processed by a worker when it gets killed then it won't get processed on restart. Introduced in v1.37.1. diff --git a/changelog.d/11263.feature b/changelog.d/11263.feature deleted file mode 100644 index 831e76ec9f..0000000000 --- a/changelog.d/11263.feature +++ /dev/null @@ -1 +0,0 @@ -Add some background update admin APIs. diff --git a/changelog.d/11269.misc b/changelog.d/11269.misc deleted file mode 100644 index a2149c2d2d..0000000000 --- a/changelog.d/11269.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up trivial aspects of the Debian package build tooling. diff --git a/changelog.d/11270.misc b/changelog.d/11270.misc deleted file mode 100644 index e2181b9b2a..0000000000 --- a/changelog.d/11270.misc +++ /dev/null @@ -1 +0,0 @@ -Blacklist new SyTest that checks that key uploads are valid pending the validation being implemented in Synapse. 
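For the background update admin APIs mentioned above ([\#11263]): usage looks roughly like the following sketch, based on the Synapse admin API documentation. The host and token are placeholders, and the paths should be checked against the docs for your version.

```sh
# Check the status of background updates (requires a server admin's token):
curl -s -H "Authorization: Bearer $ADMIN_TOKEN" \
  'https://homeserver.example/_synapse/admin/v1/background_updates/status'

# Pause background updates; re-enable by POSTing {"enabled": true}:
curl -s -X POST -H "Authorization: Bearer $ADMIN_TOKEN" \
  -d '{"enabled": false}' \
  'https://homeserver.example/_synapse/admin/v1/background_updates/enabled'
```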
diff --git a/changelog.d/11273.misc b/changelog.d/11273.misc deleted file mode 100644 index a2149c2d2d..0000000000 --- a/changelog.d/11273.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up trivial aspects of the Debian package build tooling. diff --git a/changelog.d/11276.bugfix b/changelog.d/11276.bugfix deleted file mode 100644 index ce72592624..0000000000 --- a/changelog.d/11276.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix rolling back Synapse version when using workers. diff --git a/changelog.d/11278.misc b/changelog.d/11278.misc deleted file mode 100644 index 9b014bc8b4..0000000000 --- a/changelog.d/11278.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a small typo in the error response when a relation type other than 'm.annotation' is passed to `GET /rooms/{room_id}/aggregations/{event_id}`. \ No newline at end of file diff --git a/changelog.d/11282.misc b/changelog.d/11282.misc deleted file mode 100644 index 4720519cbc..0000000000 --- a/changelog.d/11282.misc +++ /dev/null @@ -1 +0,0 @@ -Require all files in synapse/ and tests/ to pass mypy unless specifically excluded. diff --git a/changelog.d/11285.misc b/changelog.d/11285.misc deleted file mode 100644 index 4720519cbc..0000000000 --- a/changelog.d/11285.misc +++ /dev/null @@ -1 +0,0 @@ -Require all files in synapse/ and tests/ to pass mypy unless specifically excluded. diff --git a/changelog.d/11286.doc b/changelog.d/11286.doc deleted file mode 100644 index 890d7b4ee4..0000000000 --- a/changelog.d/11286.doc +++ /dev/null @@ -1 +0,0 @@ -Fix typo in the word `available` and fix HTTP method (should be `GET`) for the `username_available` admin API. Contributed by Stanislav Motylkov. diff --git a/changelog.d/11331.misc b/changelog.d/11331.misc new file mode 100644 index 0000000000..1ab3a6a975 --- /dev/null +++ b/changelog.d/11331.misc @@ -0,0 +1 @@ +A test helper (`wait_for_background_updates`) no longer depends on classes defining a `store` property. diff --git a/changelog.d/11427.doc b/changelog.d/11427.doc new file mode 100644 index 0000000000..01cdfcf2b7 --- /dev/null +++ b/changelog.d/11427.doc @@ -0,0 +1 @@ +Document the usage of refresh tokens. \ No newline at end of file diff --git a/changelog.d/11520.misc b/changelog.d/11520.misc new file mode 100644 index 0000000000..2d84120e19 --- /dev/null +++ b/changelog.d/11520.misc @@ -0,0 +1 @@ +Use HTTPStatus constants in place of literals in `tests.rest.client.test_auth`. \ No newline at end of file diff --git a/changelog.d/11531.misc b/changelog.d/11531.misc new file mode 100644 index 0000000000..ed6ef3bb3e --- /dev/null +++ b/changelog.d/11531.misc @@ -0,0 +1 @@ +Add a receipt types constant for `m.read`. diff --git a/changelog.d/11535.misc b/changelog.d/11535.misc new file mode 100644 index 0000000000..580ac354ab --- /dev/null +++ b/changelog.d/11535.misc @@ -0,0 +1 @@ +Clean up `synapse.rest.admin`. \ No newline at end of file diff --git a/changelog.d/11536.misc b/changelog.d/11536.misc new file mode 100644 index 0000000000..b9191c111b --- /dev/null +++ b/changelog.d/11536.misc @@ -0,0 +1 @@ +Improvements to log messages around handling stream ids. diff --git a/debian/changelog b/debian/changelog index 74a98f0866..acc9f6049e 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,12 +1,52 @@ -matrix-synapse-py3 (1.47.0+nmu1) UNRELEASED; urgency=medium +matrix-synapse-py3 (1.49.0~rc1) stable; urgency=medium + * New synapse release 1.49.0~rc1. 
+ + -- Synapse Packaging team <packages@matrix.org> Tue, 07 Dec 2021 13:52:21 +0000 + +matrix-synapse-py3 (1.48.0) stable; urgency=medium + + * New synapse release 1.48.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 30 Nov 2021 11:24:15 +0000 + +matrix-synapse-py3 (1.48.0~rc1) stable; urgency=medium + + * New synapse release 1.48.0~rc1. + + -- Synapse Packaging team <packages@matrix.org> Thu, 25 Nov 2021 15:56:03 +0000 + +matrix-synapse-py3 (1.47.1) stable; urgency=medium + + * New synapse release 1.47.1. + + -- Synapse Packaging team <packages@matrix.org> Fri, 19 Nov 2021 13:44:32 +0000 + +matrix-synapse-py3 (1.47.0) stable; urgency=medium + + * New synapse release 1.47.0. + + -- Synapse Packaging team <packages@matrix.org> Wed, 17 Nov 2021 13:09:43 +0000 + +matrix-synapse-py3 (1.47.0~rc3) stable; urgency=medium + + * New synapse release 1.47.0~rc3. + + -- Synapse Packaging team <packages@matrix.org> Tue, 16 Nov 2021 14:32:47 +0000 + +matrix-synapse-py3 (1.47.0~rc2) stable; urgency=medium + + [ Dan Callahan ] * Update scripts to pass Shellcheck lints. * Remove unused Vagrant scripts from debian/ directory. * Allow building Debian packages for any architecture, not just amd64. * Preinstall the "wheel" package when building virtualenvs. * Do not error if /etc/default/matrix-synapse is missing. - -- Dan Callahan <danc@element.io> Fri, 22 Oct 2021 22:20:31 +0000 + [ Synapse Packaging team ] + * New synapse release 1.47.0~rc2. + + -- Synapse Packaging team <packages@matrix.org> Wed, 10 Nov 2021 09:41:01 +0000 matrix-synapse-py3 (1.46.0) stable; urgency=medium diff --git a/docker/Dockerfile-workers b/docker/Dockerfile-workers index 969cf97286..46f2e17382 100644 --- a/docker/Dockerfile-workers +++ b/docker/Dockerfile-workers @@ -21,3 +21,6 @@ VOLUME ["/data"] # files to run the desired worker configuration. Will start supervisord. COPY ./docker/configure_workers_and_start.py /configure_workers_and_start.py ENTRYPOINT ["/configure_workers_and_start.py"] + +HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \ + CMD /bin/sh /healthcheck.sh diff --git a/docker/conf-workers/healthcheck.sh.j2 b/docker/conf-workers/healthcheck.sh.j2 new file mode 100644 index 0000000000..79c621f89c --- /dev/null +++ b/docker/conf-workers/healthcheck.sh.j2 @@ -0,0 +1,6 @@ +#!/bin/sh +# This healthcheck script is designed to return OK when every +# host involved returns OK +{%- for healthcheck_url in healthcheck_urls %} +curl -fSs {{ healthcheck_url }} || exit 1 +{%- endfor %} diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml index 3cba594d02..f10f78a48c 100644 --- a/docker/conf/homeserver.yaml +++ b/docker/conf/homeserver.yaml @@ -148,14 +148,6 @@ bcrypt_rounds: 12 allow_guest_access: {{ "True" if SYNAPSE_ALLOW_GUEST else "False" }} enable_group_creation: true -# The list of identity servers trusted to verify third party -# identifiers by this server. -# -# Also defines the ID server which will be called when an account is -# deactivated (one will be picked arbitrarily). 
-trusted_third_party_id_servers: - - matrix.org - - vector.im ## Metrics ### diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index efb9476cd6..adbb551cee 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -48,7 +48,7 @@ WORKERS_CONFIG = { "app": "synapse.app.user_dir", "listener_resources": ["client"], "endpoint_patterns": [ - "^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$" + "^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$" ], "shared_extra_conf": {"update_user_directory": False}, "worker_extra_conf": "", @@ -85,10 +85,10 @@ WORKERS_CONFIG = { "app": "synapse.app.generic_worker", "listener_resources": ["client"], "endpoint_patterns": [ - "^/_matrix/client/(v2_alpha|r0)/sync$", - "^/_matrix/client/(api/v1|v2_alpha|r0)/events$", - "^/_matrix/client/(api/v1|r0)/initialSync$", - "^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$", + "^/_matrix/client/(v2_alpha|r0|v3)/sync$", + "^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$", + "^/_matrix/client/(api/v1|r0|v3)/initialSync$", + "^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$", ], "shared_extra_conf": {}, "worker_extra_conf": "", @@ -146,11 +146,11 @@ WORKERS_CONFIG = { "app": "synapse.app.generic_worker", "listener_resources": ["client"], "endpoint_patterns": [ - "^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact", - "^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send", - "^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$", - "^/_matrix/client/(api/v1|r0|unstable)/join/", - "^/_matrix/client/(api/v1|r0|unstable)/profile/", + "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/redact", + "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/send", + "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$", + "^/_matrix/client/(api/v1|r0|v3|unstable)/join/", + "^/_matrix/client/(api/v1|r0|v3|unstable)/profile/", ], "shared_extra_conf": {}, "worker_extra_conf": "", @@ -158,7 +158,7 @@ WORKERS_CONFIG = { "frontend_proxy": { "app": "synapse.app.frontend_proxy", "listener_resources": ["client", "replication"], - "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|unstable)/keys/upload"], + "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"], "shared_extra_conf": {}, "worker_extra_conf": ( "worker_main_http_uri: http://127.0.0.1:%d" @@ -474,10 +474,16 @@ def generate_worker_files(environ, config_path: str, data_dir: str): # Determine the load-balancing upstreams to configure nginx_upstream_config = "" + + # At the same time, prepare a list of internal endpoints to healthcheck + # starting with the main process which exists even if no workers do. 
+ healthcheck_urls = ["http://localhost:8080/health"] + for upstream_worker_type, upstream_worker_ports in nginx_upstreams.items(): body = "" for port in upstream_worker_ports: body += " server localhost:%d;\n" % (port,) + healthcheck_urls.append("http://localhost:%d/health" % (port,)) # Add to the list of configured upstreams nginx_upstream_config += NGINX_UPSTREAM_CONFIG_BLOCK.format( @@ -510,6 +516,13 @@ def generate_worker_files(environ, config_path: str, data_dir: str): worker_config=supervisord_config, ) + # healthcheck config + convert( + "/conf/healthcheck.sh.j2", + "/healthcheck.sh", + healthcheck_urls=healthcheck_urls, + ) + # Ensure the logging directory exists log_dir = data_dir + "/logs" if not os.path.exists(log_dir): diff --git a/docs/README.md b/docs/README.md index 6d70f5afff..5222ee5f03 100644 --- a/docs/README.md +++ b/docs/README.md @@ -50,8 +50,10 @@ build the documentation with: mdbook build ``` -The rendered contents will be outputted to a new `book/` directory at the root of the repository. You can -browse the book by opening `book/index.html` in a web browser. +The rendered contents will be outputted to a new `book/` directory at the root of the repository. Please note that +index.html is not built by default, it is created by copying over the file `welcome_and_overview.html` to `index.html` +during deployment. Thus, when running `mdbook serve` locally the book will initially show a 404 in place of the index +due to the above. Do not be alarmed! You can also have mdbook host the docs on a local webserver with hot-reload functionality via: diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 04320ab07b..11f597b3ed 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -23,13 +23,14 @@ - [Structured Logging](structured_logging.md) - [Templates](templates.md) - [User Authentication](usage/configuration/user_authentication/README.md) - - [Single-Sign On]() + - [Single-Sign On](usage/configuration/user_authentication/single_sign_on/README.md) - [OpenID Connect](openid.md) - - [SAML]() - - [CAS]() + - [SAML](usage/configuration/user_authentication/single_sign_on/saml.md) + - [CAS](usage/configuration/user_authentication/single_sign_on/cas.md) - [SSO Mapping Providers](sso_mapping_providers.md) - [Password Auth Providers](password_auth_providers.md) - [JSON Web Tokens](jwt.md) + - [Refresh Tokens](usage/configuration/user_authentication/refresh_tokens.md) - [Registration Captcha](CAPTCHA_SETUP.md) - [Application Services](application_services.md) - [Server Notices](server_notices.md) @@ -44,6 +45,7 @@ - [Presence router callbacks](modules/presence_router_callbacks.md) - [Account validity callbacks](modules/account_validity_callbacks.md) - [Password auth provider callbacks](modules/password_auth_provider_callbacks.md) + - [Background update controller callbacks](modules/background_update_controller_callbacks.md) - [Porting a legacy module to the new interface](modules/porting_legacy_module.md) - [Workers](workers.md) - [Using `synctl` with Workers](synctl_workers.md) @@ -64,9 +66,15 @@ - [Statistics](admin_api/statistics.md) - [Users](admin_api/user_admin_api.md) - [Server Version](admin_api/version_api.md) + - [Federation](usage/administration/admin_api/federation.md) - [Manhole](manhole.md) - [Monitoring](metrics-howto.md) + - [Understanding Synapse Through Grafana Graphs](usage/administration/understanding_synapse_through_grafana_graphs.md) + - [Useful SQL for Admins](usage/administration/useful_sql_for_admins.md) + - [Database Maintenance 
Tools](usage/administration/database_maintenance_tools.md) + - [State Groups](usage/administration/state_groups.md) - [Request log format](usage/administration/request_log.md) + - [Admin FAQ](usage/administration/admin_faq.md) - [Scripts]() # Development @@ -94,3 +102,4 @@ # Other - [Dependency Deprecation Policy](deprecation_policy.md) + - [Running Synapse on a Single-Board Computer](other/running_synapse_on_single_board_computers.md) diff --git a/docs/admin_api/purge_history_api.md b/docs/admin_api/purge_history_api.md index bd29e29ab8..277e28d9cb 100644 --- a/docs/admin_api/purge_history_api.md +++ b/docs/admin_api/purge_history_api.md @@ -70,6 +70,8 @@ This API returns a JSON body like the following: The status will be one of `active`, `complete`, or `failed`. +If `status` is `failed`, there will be a string `error` containing the error message. + ## Reclaim disk space (Postgres) To reclaim the disk space and return it to the operating system, you need to run diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 41a4961d00..0f1a74134f 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -3,7 +3,11 @@ - [Room Details API](#room-details-api) - [Room Members API](#room-members-api) - [Room State API](#room-state-api) +- [Block Room API](#block-room-api) - [Delete Room API](#delete-room-api) + * [Version 1 (old version)](#version-1-old-version) + * [Version 2 (new version)](#version-2-new-version) + * [Status of deleting rooms](#status-of-deleting-rooms) * [Undoing room shutdowns](#undoing-room-shutdowns) - [Make Room Admin API](#make-room-admin-api) - [Forward Extremities Admin API](#forward-extremities-admin-api) @@ -383,6 +387,83 @@ A response body like the following is returned: } ``` +# Block Room API +The Block Room admin API allows server admins to block and unblock rooms, +and query to see if a given room is blocked. +This API can be used to pre-emptively block a room, even if it's unknown to this +homeserver. Users will be prevented from joining a blocked room. + +## Block or unblock a room + +The API is: + +``` +PUT /_synapse/admin/v1/rooms/<room_id>/block +``` + +with a body of: + +```json +{ + "block": true +} +``` + +A response body like the following is returned: + +```json +{ + "block": true +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `room_id` - The ID of the room. + +The following JSON body parameters are available: + +- `block` - If `true` the room will be blocked and if `false` the room will be unblocked. + +**Response** + +The following fields are possible in the JSON response body: + +- `block` - A boolean. `true` if the room is blocked, otherwise `false`. + +## Get block status + +The API is: + +``` +GET /_synapse/admin/v1/rooms/<room_id>/block +``` + +A response body like the following is returned: + +```json +{ + "block": true, + "user_id": "<user_id>" +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +- `room_id` - The ID of the room. + +**Response** + +The following fields are possible in the JSON response body: + +- `block` - A boolean. `true` if the room is blocked, otherwise `false`. +- `user_id` - An optional string. If the room is blocked (`block` is `true`), shows + the user who added the room to the blocking list. Otherwise it is not displayed. + # Delete Room API The Delete Room admin API allows server admins to remove rooms from the server @@ -397,10 +478,10 @@ as room administrator and will contain a message explaining what happened.
Users to the new room will have power level `-10` by default, and thus be unable to speak. If `block` is `true`, users will be prevented from joining the old room. -This option can also be used to pre-emptively block a room, even if it's unknown -to this homeserver. In this case, the room will be blocked, and no further action -will be taken. If `block` is `false`, attempting to delete an unknown room is -invalid and will be rejected as a bad request. +This option can, in [Version 1](#version-1-old-version), also be used to pre-emptively +block a room, even if it's unknown to this homeserver. In this case, the room will be +blocked, and no further action will be taken. If `block` is `false`, attempting to +delete an unknown room is invalid and will be rejected as a bad request. This API will remove all trace of the old room from your database after removing all local users. If `purge` is `true` (the default), all traces of the old room will @@ -412,6 +493,17 @@ several minutes or longer. The local server will only have the power to move local user and room aliases to the new room. Users on other servers will be unaffected. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see [Admin API](../usage/administration/admin_api). + +## Version 1 (old version) + +This version works synchronously. That means you only get the response once the server has +finished the action, which may take a long time. If you request the same action +a second time, and the server has not finished the first one, the second request will block. +This is fixed in version 2 of this API. The parameters are the same in both APIs. +This API will be deprecated in the future. + The API is: ``` @@ -430,9 +522,6 @@ with a body of: } ``` -To use it, you will need to authenticate by providing an ``access_token`` for a -server admin: see [Admin API](../usage/administration/admin_api). - A response body like the following is returned: ```json @@ -449,6 +538,44 @@ A response body like the following is returned: } ``` +The parameters and response values have the same format as +[version 2](#version-2-new-version) of the API. + +## Version 2 (new version) + +**Note**: This API is new, experimental and "subject to change". + +This version works asynchronously, meaning you get the response from the server immediately +while the server works on the task in the background. You can then request the status of the action +to check if it has completed. + +The API is: + +``` +DELETE /_synapse/admin/v2/rooms/<room_id> +``` + +with a body of: + +```json +{ + "new_room_user_id": "@someuser:example.com", + "room_name": "Content Violation Notification", + "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.", + "block": true, + "purge": true +} +``` + +The API starts the shutdown and purge, and returns immediately with a JSON body with +a purge id: + +```json +{ + "delete_id": "<opaque id>" +} +``` + **Parameters** The following parameters should be set in the URL: @@ -470,7 +597,8 @@ The following JSON body parameters are available: is not permitted and rooms in violation will be blocked.` * `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to join the room. Rooms can be blocked - even if they're not yet known to the homeserver. Defaults to `false`. + even if they're not yet known to the homeserver (only with + [Version 1](#version-1-old-version) of the API).
Defaults to `false`. * `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. Defaults to `true`. * `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it @@ -480,17 +608,124 @@ The following JSON body parameters are available: The JSON body must not be empty. The body must be at least `{}`. -**Response** +## Status of deleting rooms -The following fields are returned in the JSON response body: +**Note**: This API is new, experimental and "subject to change". + +It is possible to query the status of the background task for deleting rooms. +The status can be queried up to 24 hours after completion of the task, +or until Synapse is restarted (whichever happens first). + +### Query by `room_id` + +With this API you can get the status of all active deletion tasks, and all those completed in the last 24h, +for the given `room_id`. + +The API is: + +``` +GET /_synapse/admin/v2/rooms/<room_id>/delete_status +``` + +A response body like the following is returned: -* `kicked_users` - An array of users (`user_id`) that were kicked. -* `failed_to_kick_users` - An array of users (`user_id`) that that were not kicked. -* `local_aliases` - An array of strings representing the local aliases that were migrated from - the old room to the new. -* `new_room_id` - A string representing the room ID of the new room, or `null` if - no such room was created. +```json +{ + "results": [ + { + "delete_id": "delete_id1", + "status": "failed", + "error": "error message", + "shutdown_room": { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": null + } + }, { + "delete_id": "delete_id2", + "status": "purging", + "shutdown_room": { + "kicked_users": [ + "@foobar:example.com" + ], + "failed_to_kick_users": [], + "local_aliases": [ + "#badroom:example.com", + "#evilsaloon:example.com" + ], + "new_room_id": "!newroomid:example.com" + } + } + ] +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +* `room_id` - The ID of the room. + +### Query by `delete_id` + +With this API you can get the status of one specific task by `delete_id`. + +The API is: + +``` +GET /_synapse/admin/v2/rooms/delete_status/<delete_id> +``` + +A response body like the following is returned: + +```json +{ + "status": "purging", + "shutdown_room": { + "kicked_users": [ + "@foobar:example.com" + ], + "failed_to_kick_users": [], + "local_aliases": [ + "#badroom:example.com", + "#evilsaloon:example.com" + ], + "new_room_id": "!newroomid:example.com" + } +} +``` + +**Parameters** + +The following parameters should be set in the URL: + +* `delete_id` - The ID for this delete. + +### Response + +The following fields are returned in the JSON response body: +- `results` - An array of objects, each containing information about one task. + This field is omitted from the result when you query by `delete_id`. + Task objects contain the following fields: + - `delete_id` - The ID for this purge if you query by `room_id`. + - `status` - The status will be one of: + - `shutting_down` - The process is removing users from the room. + - `purging` - The process is purging the room and event data from database. + - `complete` - The process has completed successfully. + - `failed` - The process is aborted, an error has occurred. + - `error` - A string that shows an error message if `status` is `failed`. + Otherwise this field is hidden. + - `shutdown_room` - An object containing information about the result of shutting down the room. 
+ *Note:* The result is shown after removing the room members. + The delete process may still be running. Please pay attention to the `status`. + - `kicked_users` - An array of users (`user_id`) that were kicked. + - `failed_to_kick_users` - An array of users (`user_id`) that were not kicked. + - `local_aliases` - An array of strings representing the local aliases that were + migrated from the old room to the new. + - `new_room_id` - A string representing the room ID of the new room, or `null` if + no such room was created. ## Undoing room deletions diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 16ec33b3c1..ba574d795f 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -948,7 +948,7 @@ The following fields are returned in the JSON response body: See also the [Client-Server API Spec on pushers](https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers). -## Shadow-banning users +## Controlling whether a user is shadow-banned Shadow-banning is a useful tool for moderating malicious or egregiously abusive users. A shadow-banned user receives successful responses to their client-server API requests, @@ -961,16 +961,22 @@ or broken behaviour for the client. A shadow-banned user will not receive any notification and it is generally more appropriate to ban or kick abusive users. A shadow-banned user will be unable to contact anyone on the server. -The API is: +To shadow-ban a user, the API is: ``` POST /_synapse/admin/v1/users/<user_id>/shadow_ban ``` +To un-shadow-ban a user, the API is: + +``` +DELETE /_synapse/admin/v1/users/<user_id>/shadow_ban +``` + To use it, you will need to authenticate by providing an `access_token` for a server admin: [Admin API](../usage/administration/admin_api) -An empty JSON dict is returned. +An empty JSON dict is returned in both cases. **Parameters** diff --git a/docs/ancient_architecture_notes.md b/docs/ancient_architecture_notes.md index 3ea8976cc7..07bb199d7a 100644 --- a/docs/ancient_architecture_notes.md +++ b/docs/ancient_architecture_notes.md @@ -7,7 +7,7 @@ ## Server to Server Stack -To use the server to server stack, home servers should only need to +To use the server to server stack, homeservers should only need to interact with the Messaging layer. The server to server side of things is designed into 4 distinct layers: @@ -23,7 +23,7 @@ Server with a domain specific API. 1. **Messaging Layer** - This is what the rest of the Home Server hits to send messages, join rooms, + This is what the rest of the homeserver hits to send messages, join rooms, etc. It also allows you to register callbacks for when it gets notified by lower levels that e.g. a new message has been received. @@ -45,7 +45,7 @@ Server with a domain specific API. For incoming PDUs, it has to check the PDUs it references to see if we have missed any. If we have, go and ask someone (another - home server) for it. + homeserver) for it. 3. **Transaction Layer** diff --git a/docs/development/room-dag-concepts.md b/docs/development/room-dag-concepts.md index 5eed72bec6..cbc7cf2949 100644 --- a/docs/development/room-dag-concepts.md +++ b/docs/development/room-dag-concepts.md @@ -38,16 +38,15 @@ Most-recent-in-time events in the DAG which are not referenced by any other even The forward extremities of a room are used as the `prev_events` when the next event is sent.
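(To make "forward extremities" concrete: they can be inspected directly in the database. A rough sketch, assuming the `event_forward_extremities` table used by current Synapse schemas; names may differ between versions:)

```sql
-- The events currently at the "front" of the room's DAG; their IDs are
-- what the server will use as prev_events for the next event it sends.
SELECT event_id
  FROM event_forward_extremities
 WHERE room_id = '!someroom:example.com';
```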
-## Backwards extremity +## Backward extremity The current marker of where we have backfilled up to and will generally be the -oldest-in-time events we know of in the DAG. +`prev_events` of the oldest-in-time events we have in the DAG. This gives a starting point when +backfilling history. -This is an event where we haven't fetched all of the `prev_events` for. - -Once we have fetched all of its `prev_events`, it's unmarked as a backwards -extremity (although we may have formed new backwards extremities from the prev -events during the backfilling process). +When we persist a non-outlier event, we clear it as a backward extremity and set +all of its `prev_events` as the new backward extremities if they aren't already +persisted in the `events` table. ## Outliers @@ -56,8 +55,7 @@ We mark an event as an `outlier` when we haven't figured out the state for the room at that point in the DAG yet. We won't *necessarily* have the `prev_events` of an `outlier` in the database, -but it's entirely possible that we *might*. The status of whether we have all of -the `prev_events` is marked as a [backwards extremity](#backwards-extremity). +but it's entirely possible that we *might*. For example, when we fetch the event auth chain or state for a given event, we mark all of those claimed auth events as outliers because we haven't done the diff --git a/docs/jwt.md b/docs/jwt.md index 5be9fd26e3..32f58cc0cb 100644 --- a/docs/jwt.md +++ b/docs/jwt.md @@ -22,8 +22,9 @@ will be removed in a future version of Synapse. The `token` field should include the JSON web token with the following claims: -* The `sub` (subject) claim is required and should encode the local part of the - user ID. +* A claim that encodes the local part of the user ID is required. By default, + the `sub` (subject) claim is used, or a custom claim can be set in the + configuration file. * The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`) claims are optional, but validated if present. * The issuer (`iss`) claim is optional, but required and validated if configured. diff --git a/docs/media_repository.md b/docs/media_repository.md index 99ee8f1ef7..ba17f8a856 100644 --- a/docs/media_repository.md +++ b/docs/media_repository.md @@ -2,29 +2,80 @@ *Synapse implementation-specific details for the media repository* -The media repository is where attachments and avatar photos are stored. -It stores attachment content and thumbnails for media uploaded by local users. -It caches attachment content and thumbnails for media uploaded by remote users. +The media repository + * stores avatars, attachments and their thumbnails for media uploaded by local + users. + * caches avatars, attachments and their thumbnails for media uploaded by remote + users. + * caches resources and thumbnails used for + [URL previews](development/url_previews.md). -## Storage +All media in Matrix can be identified by a unique +[MXC URI](https://spec.matrix.org/latest/client-server-api/#matrix-content-mxc-uris), +consisting of a server name and media ID: +``` +mxc://<server-name>/<media-id> +``` -Each item of media is assigned a `media_id` when it is uploaded. -The `media_id` is a randomly chosen, URL safe 24 character string. +## Local Media +Synapse generates 24 character media IDs for content uploaded by local users. +These media IDs consist of upper and lowercase letters and are case-sensitive. +Other homeserver implementations may generate media IDs differently. 
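As a rough illustration of this ID scheme — 24 case-sensitive ASCII letters — a generator might look like the following sketch (illustrative only, not Synapse's actual implementation):

```python
import secrets
import string

def generate_media_id(length: int = 24) -> str:
    """Generate a random, case-sensitive media ID of ASCII letters,
    matching the shape described above (illustrative sketch)."""
    return "".join(secrets.choice(string.ascii_letters) for _ in range(length))

print(generate_media_id())  # e.g. 'aQbRzXwLkJhGfDsApOiUyTrE'
```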
-Metadata such as the MIME type, upload time and length are stored in the -sqlite3 database indexed by `media_id`. +Local media is recorded in the `local_media_repository` table, which includes +metadata such as MIME types, upload times and file sizes. +Note that this table is shared by the URL cache, which has a different media ID +scheme. -Content is stored on the filesystem under a `"local_content"` directory. +### Paths +A file with media ID `aabbcccccccccccccccccccc` and its `128x96` `image/jpeg` +thumbnail, created by scaling, would be stored at: +``` +local_content/aa/bb/cccccccccccccccccccc +local_thumbnails/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale +``` -Thumbnails are stored under a `"local_thumbnails"` directory. +## Remote Media +When media from a remote homeserver is requested from Synapse, it is assigned +a local `filesystem_id`, with the same format as locally-generated media IDs, +as described above. -The item with `media_id` `"aabbccccccccdddddddddddd"` is stored under -`"local_content/aa/bb/ccccccccdddddddddddd"`. Its thumbnail with width -`128` and height `96` and type `"image/jpeg"` is stored under -`"local_thumbnails/aa/bb/ccccccccdddddddddddd/128-96-image-jpeg"` +A record of remote media is stored in the `remote_media_cache` table, which +can be used to map remote MXC URIs (server names and media IDs) to local +`filesystem_id`s. -Remote content is cached under `"remote_content"` directory. Each item of -remote content is assigned a local `"filesystem_id"` to ensure that the -directory structure `"remote_content/server_name/aa/bb/ccccccccdddddddddddd"` -is appropriate. Thumbnails for remote content are stored under -`"remote_thumbnail/server_name/..."` +### Paths +A file from `matrix.org` with `filesystem_id` `aabbcccccccccccccccccccc` and its +`128x96` `image/jpeg` thumbnail, created by scaling, would be stored at: +``` +remote_content/matrix.org/aa/bb/cccccccccccccccccccc +remote_thumbnail/matrix.org/aa/bb/cccccccccccccccccccc/128-96-image-jpeg-scale +``` +Older thumbnails may omit the thumbnailing method: +``` +remote_thumbnail/matrix.org/aa/bb/cccccccccccccccccccc/128-96-image-jpeg +``` + +Note that `remote_thumbnail/` does not have an `s`. + +## URL Previews +See [URL Previews](development/url_previews.md) for documentation on the URL preview +process. + +When generating previews for URLs, Synapse may download and cache various +resources, including images. These resources are assigned temporary media IDs +of the form `yyyy-mm-dd_aaaaaaaaaaaaaaaa`, where `yyyy-mm-dd` is the current +date and `aaaaaaaaaaaaaaaa` is a random sequence of 16 case-sensitive letters. + +The metadata for these cached resources is stored in the +`local_media_repository` and `local_media_repository_url_cache` tables. + +Resources for URL previews are deleted after a few days. + +### Paths +The file with media ID `yyyy-mm-dd_aaaaaaaaaaaaaaaa` and its `128x96` +`image/jpeg` thumbnail, created by scaling, would be stored at: +``` +url_cache/yyyy-mm-dd/aaaaaaaaaaaaaaaa +url_cache_thumbnails/yyyy-mm-dd/aaaaaaaaaaaaaaaa/128-96-image-jpeg-scale +``` diff --git a/docs/modules/background_update_controller_callbacks.md b/docs/modules/background_update_controller_callbacks.md new file mode 100644 index 0000000000..b3e7c259f4 --- /dev/null +++ b/docs/modules/background_update_controller_callbacks.md @@ -0,0 +1,71 @@ +# Background update controller callbacks + +Background update controller callbacks allow module developers to control (e.g. rate-limit) +how database background updates are run. 
A database background update is an operation +Synapse runs on its database in the background after it starts. It's usually used to run +database operations that would take too long if they were run at the same time as schema +updates (which are run on startup) and delay Synapse's startup too much: populating a +table with a large amount of data, adding an index on a big table, deleting superfluous data, +etc. + +Background update controller callbacks can be registered using the module API's +`register_background_update_controller_callbacks` method. Only the first module (in order +of appearance in Synapse's configuration file) calling this method can register background +update controller callbacks; subsequent calls are ignored. + +The available background update controller callbacks are: + +### `on_update` + +_First introduced in Synapse v1.49.0_ + +```python +def on_update(update_name: str, database_name: str, one_shot: bool) -> AsyncContextManager[int] +``` + +Called when about to do an iteration of a background update. The module is given the name +of the update, the name of the database, and a flag to indicate whether the background +update will happen in one go and may take a long time (e.g. creating indices). If this last +argument is set to `False`, the update will be run in batches. + +The module must return an async context manager. It will be entered before Synapse runs a +background update; this should return the desired duration of the iteration, in +milliseconds. + +The context manager will be exited when the iteration completes. Note that the duration +returned by the context manager is a target, and an iteration may take substantially longer +or shorter. If the `one_shot` flag is set to `True`, the duration returned is ignored. + +__Note__: Unlike most module callbacks in Synapse, this one is _synchronous_. This is +because asynchronous operations are expected to be run by the async context manager. + +This callback is required when registering any other background update controller callback. + +### `default_batch_size` + +_First introduced in Synapse v1.49.0_ + +```python +async def default_batch_size(update_name: str, database_name: str) -> int +``` + +Called before the first iteration of a background update, with the name of the update and +of the database. The module must return the number of elements to process in this first +iteration. + +If this callback is not defined, Synapse will use a default value of 100. + +### `min_batch_size` + +_First introduced in Synapse v1.49.0_ + +```python +async def min_batch_size(update_name: str, database_name: str) -> int +``` + +Called before running a new batch for a background update, with the name of the update and +of the database. The module must return an integer representing the minimum number of +elements to process in this iteration. This number must be at least 1, and is used to +ensure that progress is always made. + +If this callback is not defined, Synapse will use a default value of 100. diff --git a/docs/modules/writing_a_module.md b/docs/modules/writing_a_module.md index 7764e06692..e7c0ffad58 100644 --- a/docs/modules/writing_a_module.md +++ b/docs/modules/writing_a_module.md @@ -71,15 +71,15 @@ Modules **must** register their web resources in their `__init__` method. ## Registering a callback Modules can use Synapse's module API to register callbacks. Callbacks are functions that -Synapse will call when performing specific actions. Callbacks must be asynchronous, and -are split in categories.
A single module may implement callbacks from multiple categories, -and is under no obligation to implement all callbacks from the categories it registers -callbacks for. +Synapse will call when performing specific actions. Callbacks must be asynchronous (unless +specified otherwise), and are split into categories. A single module may implement callbacks +from multiple categories, and is under no obligation to implement all callbacks from the +categories it registers callbacks for. Modules can register callbacks using one of the module API's `register_[...]_callbacks` methods. The callback functions are passed to these methods as keyword arguments, with -the callback name as the argument name and the function as its value. This is demonstrated -in the example below. A `register_[...]_callbacks` method exists for each category. +the callback name as the argument name and the function as its value. A +`register_[...]_callbacks` method exists for each category. Callbacks for each category can be found on their respective page of the [Synapse documentation website](https://matrix-org.github.io/synapse). \ No newline at end of file diff --git a/docs/openid.md b/docs/openid.md index c74e8bda60..ff9de9d5b8 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -83,7 +83,7 @@ oidc_providers: ### Dex -[Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider. +[Dex][dex-idp] is a simple, open-source OpenID Connect Provider. Although it is designed to help build a full-blown provider with an external database, it can be configured with static passwords in a config file. @@ -523,7 +523,7 @@ The synapse config will look like this: email_template: "{{ user.email }}" ``` -## Django OAuth Toolkit +### Django OAuth Toolkit [django-oauth-toolkit](https://github.com/jazzband/django-oauth-toolkit) is a Django application providing out of the box all the endpoints, data and logic diff --git a/docs/other/running_synapse_on_single_board_computers.md b/docs/other/running_synapse_on_single_board_computers.md new file mode 100644 index 0000000000..ea14afa8b2 --- /dev/null +++ b/docs/other/running_synapse_on_single_board_computers.md @@ -0,0 +1,74 @@ +## Summary of performance impact of running on resource constrained devices such as SBCs + +I've been running my homeserver on a cubietruck at home now for some time and am often replying to statements like "you need loads of ram to join large rooms" with "it works fine for me". I thought it might be useful to curate a summary of the issues you're likely to run into to help as a scaling-down guide, maybe highlight these for development work or end up as documentation. It seems that once you get up to about 4x1.5GHz arm64 4GiB these issues are no longer a problem. + +- **Platform**: 2x1GHz armhf 2GiB ram [Single-board computers](https://wiki.debian.org/CheapServerBoxHardware), SSD, postgres. + +### Presence + +This is the main reason people have a poor matrix experience on resource constrained homeservers. Element web will frequently be saying the server is offline while the python process will be pegged at 100% cpu. This feature is used to tell when other users are active (have a client app in the foreground) and therefore more likely to respond, but requires a lot of network activity to maintain even when nobody is talking in a room.
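(If you just want the switch now: presence can be disabled server-wide with a single option in homeserver.yaml, shown again in context in the recommended configuration at the end of this page:)

```
# Disable presence tracking on this homeserver entirely.
use_presence: false
```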
+ +![Screenshot_2020-10-01_19-29-46](https://user-images.githubusercontent.com/71895/94848963-a47a3580-041c-11eb-8b6e-acb772b4259e.png) + +While synapse does have some performance issues with presence [#3971](https://github.com/matrix-org/synapse/issues/3971), the fundamental problem is that this is an easy feature to implement for a centralised service at nearly no overhead, but federation makes it combinatorial [#8055](https://github.com/matrix-org/synapse/issues/8055). There is also a client-side config option which disables the UI and idle tracking [enable_presence_by_hs_url] to blacklist the largest instances but I didn't notice much difference, so I recommend disabling the feature entirely at the server level as well. + +[enable_presence_by_hs_url]: https://github.com/vector-im/element-web/blob/v1.7.8/config.sample.json#L45 + +### Joining + +Joining a "large", federated room will initially fail with the below message in Element web, but waiting a while (10-60mins) and trying again will succeed without any issue. What counts as "large" is not message history, user count, connections to homeservers or even a simple count of the state events; it is instead how long the state resolution algorithm takes. However, each of those numbers is a reasonable proxy, so we can use them as estimates since user count is one of the few things you see before joining. + +![Screenshot_2020-10-02_17-15-06](https://user-images.githubusercontent.com/71895/94945781-18771500-04d3-11eb-8419-83c2da73a341.png) + +This is [#1211](https://github.com/matrix-org/synapse/issues/1211) and will also hopefully be mitigated by peeking [matrix-org/matrix-doc#2753](https://github.com/matrix-org/matrix-doc/pull/2753) so at least you don't need to wait for a join to complete before finding out if it's the kind of room you want. Note that you should first disable presence, otherwise it'll just make the situation worse [#3120](https://github.com/matrix-org/synapse/issues/3120). There is a lot of database interaction too, so make sure you've [migrated your data](../postgres.md) from the default sqlite to postgresql. Personally, I recommend patience - once the initial join is complete there's rarely any issues with actually interacting with the room, but if you like you can just block "large" rooms entirely. + +### Sessions + +Anything that requires modifying the device list [#7721](https://github.com/matrix-org/synapse/issues/7721) will take a while to propagate, again taking the client "Offline" until it's complete. This includes signing in and out, editing the public name and verifying e2ee. The main mitigation I recommend is to keep long-running sessions open e.g. by using Firefox SSB "Use this site in App mode" or Chromium PWA "Install Element". + +### Recommended configuration + +Put the below in a new file at /etc/matrix-synapse/conf.d/sbc.yaml to override the defaults in homeserver.yaml. + +``` +# Set to false to disable presence tracking on this homeserver. +use_presence: false + +# When this is enabled, the room "complexity" will be checked before a user +# joins a new remote room. If it is above the complexity limit, the server will +# disallow joining, or will instantly leave. +limit_remote_rooms: + # Uncomment to enable room complexity checking.
+ #enabled: true + complexity: 3.0 + +# Database configuration +database: + name: psycopg2 + args: + user: matrix-synapse + # Generate a long, secure one with a password manager + password: hunter2 + database: matrix-synapse + host: localhost + cp_min: 5 + cp_max: 10 +``` + +Currently the complexity is measured by [current_state_events / 500](https://github.com/matrix-org/synapse/blob/v1.20.1/synapse/storage/databases/main/events_worker.py#L986). You can find join times and your most complex rooms like this: + +``` +admin@homeserver:~$ zgrep '/client/r0/join/' /var/log/matrix-synapse/homeserver.log* | awk '{print $18, $25}' | sort --human-numeric-sort +29.922sec/-0.002sec /_matrix/client/r0/join/%23debian-fasttrack%3Apoddery.com +182.088sec/0.003sec /_matrix/client/r0/join/%23decentralizedweb-general%3Amatrix.org +911.625sec/-570.847sec /_matrix/client/r0/join/%23synapse%3Amatrix.org + +admin@homeserver:~$ sudo --user postgres psql matrix-synapse --command 'select canonical_alias, joined_members, current_state_events from room_stats_state natural join room_stats_current where canonical_alias is not null order by current_state_events desc fetch first 5 rows only' + canonical_alias | joined_members | current_state_events +-------------------------------+----------------+---------------------- + #_oftc_#debian:matrix.org | 871 | 52355 + #matrix:matrix.org | 6379 | 10684 + #irc:matrix.org | 461 | 3751 + #decentralizedweb-general:matrix.org | 997 | 1509 + #whatsapp:maunium.net | 554 | 854 +``` \ No newline at end of file diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index d7beacfff3..dc0dfffa21 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -1,7 +1,7 @@ <h2 style="color:red"> This page of the Synapse documentation is now deprecated. For up to date documentation on setting up or writing a password auth provider module, please see -<a href="modules.md">this page</a>. +<a href="modules/index.md">this page</a>. </h2> # Password auth provider modules diff --git a/docs/postgres.md b/docs/postgres.md index 083b0aaff0..e4861c1f12 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -118,6 +118,9 @@ performance: Note that the appropriate values for those fields depend on the amount of free memory the database host has available. +Additionally, admins of large deployments might want to consider using huge pages +to help manage memory, especially when using large values of `shared_buffers`. You +can read more about that [here](https://www.postgresql.org/docs/10/kernel-resources.html#LINUX-HUGE-PAGES). ## Porting from SQLite diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d48c08f1d9..6696ed5d1e 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -647,8 +647,8 @@ retention: # #federation_certificate_verification_whitelist: # - lon.example.com -# - *.domain.com -# - *.onion +# - "*.domain.com" +# - "*.onion" # List of custom certificate authorities for federation traffic. # @@ -1209,6 +1209,44 @@ oembed: # #session_lifetime: 24h +# Time that an access token remains valid for, if the session is +# using refresh tokens. +# For more information about refresh tokens, please see the manual. +# Note that this only applies to clients which advertise support for +# refresh tokens. +# +# Note also that this is calculated at login time and refresh time: +# changes are not applied to existing sessions until they are refreshed. +# +# By default, this is 5 minutes. 
+# +#refreshable_access_token_lifetime: 5m + +# Time that a refresh token remains valid for (provided that it is not +# exchanged for another one first). +# This option can be used to automatically log-out inactive sessions. +# Please see the manual for more information. +# +# Note also that this is calculated at login time and refresh time: +# changes are not applied to existing sessions until they are refreshed. +# +# By default, this is infinite. +# +#refresh_token_lifetime: 24h + +# Time that an access token remains valid for, if the session is NOT +# using refresh tokens. +# Please note that not all clients support refresh tokens, so setting +# this to a short value may be inconvenient for some users who will +# then be logged out frequently. +# +# Note also that this is calculated at login time: changes are not applied +# retrospectively to existing sessions for users that have already logged in. +# +# By default, this is infinite. +# +#nonrefreshable_access_token_lifetime: 24h + # The user must provide all of the below types of 3PID when registering. # #registrations_require_3pid: @@ -2039,6 +2077,12 @@ sso: # #algorithm: "provided-by-your-issuer" + # Name of the claim containing a unique identifier for the user. + # + # Optional, defaults to `sub`. + # + #subject_claim: "sub" + # The issuer to validate the "iss" claim against. # # Optional, if provided the "iss" claim will be required and @@ -2360,8 +2404,8 @@ user_directory: # indexes were (re)built was before Synapse 1.44, you'll have to # rebuild the indexes in order to search through all known users. # These indexes are built the first time Synapse starts; admins can - # manually trigger a rebuild following the instructions at - # https://matrix-org.github.io/synapse/latest/user_directory.html + # manually trigger a rebuild via API following the instructions at + # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/background_updates.html#run # # Uncomment to return search results containing all known users, even if that # user does not share a room with the requester. diff --git a/docs/setup/installation.md b/docs/setup/installation.md index 3e08024441..16562be953 100644 --- a/docs/setup/installation.md +++ b/docs/setup/installation.md @@ -76,6 +76,12 @@ The fingerprint of the repository signing key (as shown by `gpg /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is `AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`. +When installing with Debian packages, you might prefer to place files in +`/etc/matrix-synapse/conf.d/` to override your configuration without editing +the main configuration file at `/etc/matrix-synapse/homeserver.yaml`. +By doing that, you won't be asked if you want to replace your configuration +file when you upgrade the Debian package to a later version. 
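For example, a small override file might look like the following (the file name and option values are illustrative; any valid homeserver option can be set this way):

```yaml
# /etc/matrix-synapse/conf.d/local-overrides.yaml (illustrative name)
# Options set here override the same keys in homeserver.yaml when the
# Debian package starts Synapse with this config directory.
report_stats: false
max_upload_size: 100M
```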
+ ##### Downstream Debian packages We do not recommend using the packages from the default Debian `buster` diff --git a/docs/templates.md b/docs/templates.md index a240f58b54..2b66e9d862 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -71,7 +71,12 @@ Below are the templates Synapse will look for when generating the content of an * `sender_avatar_url`: the avatar URL (as a `mxc://` URL) for the event's sender * `sender_hash`: a hash of the user ID of the sender + * `msgtype`: the type of the message + * `body_text_html`: HTML representation of the message + * `body_text_plain`: plaintext representation of the message + * `image_url`: the mxc URL of an image, when "msgtype" is "m.image" * `link`: a `matrix.to` link to the room + * `avator_url`: the URL of the room's avatar (note the spelling of the variable name) * `reason`: information on the event that triggered the email to be sent. It's an object with the following attributes: * `room_id`: the ID of the room the event was sent in diff --git a/docs/turn-howto.md b/docs/turn-howto.md index 99f0bb2fc2..e6812de69e 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md @@ -1,12 +1,12 @@ # Overview -This document explains how to enable VoIP relaying on your Home Server with +This document explains how to enable VoIP relaying on your homeserver with TURN. -The synapse Matrix Home Server supports integration with TURN server via the +The synapse Matrix homeserver supports integration with a TURN server via the [TURN server REST API](<https://tools.ietf.org/html/draft-uberti-behave-turn-rest-00>). This -allows the Home Server to generate credentials that are valid for use on the -TURN server through the use of a secret shared between the Home Server and the +allows the homeserver to generate credentials that are valid for use on the +TURN server through the use of a secret shared between the homeserver and the TURN server. The following sections describe how to install [coturn](<https://github.com/coturn/coturn>) (which implements the TURN REST API) and integrate it with synapse. @@ -165,18 +165,18 @@ This will install and start a systemd service called `coturn`. ## Synapse setup -Your home server configuration file needs the following extra keys: +Your homeserver configuration file needs the following extra keys: 1. "`turn_uris`": This needs to be a yaml list of public-facing URIs for your TURN server to be given out to your clients. Add separate entries for each transport your TURN server supports. 2. "`turn_shared_secret`": This is the secret shared between your - Home server and your TURN server, so you should set it to the same + homeserver and your TURN server, so you should set it to the same string you used in turnserver.conf. 3. "`turn_user_lifetime`": This is the amount of time credentials - generated by your Home Server are valid for (in milliseconds). + generated by your homeserver are valid for (in milliseconds). Shorter times offer less potential for abuse at the expense of - increased traffic between web clients and your home server to + increased traffic between web clients and your homeserver to refresh credentials. The TURN REST API specification recommends one day (86400000). 4. "`turn_allow_guests`": Whether to allow guest users to use the @@ -220,7 +220,7 @@ Here are a few things to try: anyone who has successfully set this up. * Check that you have opened your firewall to allow TCP and UDP traffic to the - TURN ports (normally 3478 and 5479). + TURN ports (normally 3478 and 5349).
* Check that you have opened your firewall to allow UDP traffic to the UDP relay ports (49152-65535 by default). diff --git a/docs/usage/administration/admin_api/background_updates.md b/docs/usage/administration/admin_api/background_updates.md index b36d7fe398..9f6ac7d567 100644 --- a/docs/usage/administration/admin_api/background_updates.md +++ b/docs/usage/administration/admin_api/background_updates.md @@ -42,7 +42,6 @@ For each update: `average_items_per_ms` how many items are processed per millisecond based on an exponential average. - ## Enabled This API allows pausing background updates. @@ -82,3 +81,29 @@ The API returns the `enabled` param. ``` There is also a `GET` version which returns the `enabled` state. + + +## Run + +This API schedules a specific background update to run. The job starts immediately after calling the API. + + +The API is: + +``` +POST /_synapse/admin/v1/background_updates/start_job +``` + +with the following body: + +```json +{ + "job_name": "populate_stats_process_rooms" +} +``` + +The following JSON body parameters are available: + +- `job_name` - A string specifying which job to run. Valid values are: + - `populate_stats_process_rooms` - Recalculate the stats for all rooms. + - `regenerate_directory` - Recalculate the [user directory](../../../user_directory.md) if it is stale or out of sync. diff --git a/docs/usage/administration/admin_api/federation.md b/docs/usage/administration/admin_api/federation.md new file mode 100644 index 0000000000..8f9535f57b --- /dev/null +++ b/docs/usage/administration/admin_api/federation.md @@ -0,0 +1,114 @@ +# Federation API + +This API allows a server administrator to manage Synapse's federation with other homeservers. + +Note: This API is new, experimental and "subject to change". + +## List of destinations + +This API gets the current destination retry timing info for all remote servers. + +The list contains all the servers with which the server federates, +regardless of whether an error occurred or not. +If an error occurs, it may take up to 20 minutes for the error to be displayed here, +as a complete retry must have failed. + +The API is: + +A standard request with no filtering: + +``` +GET /_synapse/admin/v1/federation/destinations +``` + +A response body like the following is returned: + +```json +{ + "destinations":[ + { + "destination": "matrix.org", + "retry_last_ts": 1557332397936, + "retry_interval": 3000000, + "failure_ts": 1557329397936, + "last_successful_stream_ordering": null + } + ], + "total": 1 +} +``` + +To paginate, check for `next_token` and if present, call the endpoint again +with `from` set to the value of `next_token`. This will return a new page. + +If the endpoint does not return a `next_token` then there are no more destinations +to paginate through. + +**Parameters** + +The following query parameters are available: + +- `from` - Offset in the returned list. Defaults to `0`. +- `limit` - Maximum number of destinations to return. Defaults to `100`. +- `order_by` - The method in which to sort the returned list of destinations. + Valid values are: + - `destination` - Destinations are ordered alphabetically by remote server name. + This is the default. + - `retry_last_ts` - Destinations are ordered by time of last retry attempt in ms. + - `retry_interval` - Destinations are ordered by how long until next retry in ms. + - `failure_ts` - Destinations are ordered by when the server started failing in ms.
+ - `last_successful_stream_ordering` - Destinations are ordered by the stream ordering + of the most recent successfully-sent PDU. +- `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting + this value to `b` will reverse the above sort order. Defaults to `f`. + +*Caution:* The database only has an index on the column `destination`. +This means that using a different sort order can cause a large load on the +database, especially for large environments. + +**Response** + +The following fields are returned in the JSON response body: + +- `destinations` - An array of objects, each containing information about a destination. + Destination objects contain the following fields: + - `destination` - string - Name of the remote server to federate. + - `retry_last_ts` - integer - The last time Synapse tried and failed to reach the + remote server, in ms. This is `0` if the last attempt to communicate with the + remote server was successful. + - `retry_interval` - integer - How long after the last failed attempt Synapse will wait + before trying to reach the remote server again, in ms. This is `0` if no further retrying is occurring. + - `failure_ts` - nullable integer - The first time Synapse tried and failed to reach the + remote server, in ms. This is `null` if communication with the remote server has never failed. + - `last_successful_stream_ordering` - nullable integer - The stream ordering of the most + recent successfully-sent [PDU](understanding_synapse_through_grafana_graphs.md#federation) + to this destination, or `null` if this information has not been tracked yet. +- `next_token`: string representing a positive integer - Indication for pagination. See above. +- `total` - integer - Total number of destinations. + +# Destination Details API + +This API gets the retry timing info for a specific remote server. + +The API is: + +``` +GET /_synapse/admin/v1/federation/destinations/<destination> +``` + +A response body like the following is returned: + +```json +{ + "destination": "matrix.org", + "retry_last_ts": 1557332397936, + "retry_interval": 3000000, + "failure_ts": 1557329397936, + "last_successful_stream_ordering": null +} +``` + +**Response** + +The response fields are the same as in the `destinations` array of the +[List of destinations](#list-of-destinations) response. diff --git a/docs/usage/administration/admin_faq.md b/docs/usage/administration/admin_faq.md new file mode 100644 index 0000000000..3dcad4bbef --- /dev/null +++ b/docs/usage/administration/admin_faq.md @@ -0,0 +1,103 @@ +## Admin FAQ + +How do I become a server admin? +--- +If your server already has an admin account, you should use the user admin API to promote other accounts to become admins. See [User Admin API](../../admin_api/user_admin_api.md#Change-whether-a-user-is-a-server-administrator-or-not) + +If you don't have any admin accounts yet, you won't be able to use the admin API, so you'll have to edit the database manually. Manually editing the database is generally not recommended, so once you have an admin account, use the admin APIs to make further changes. + +```sql +UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'; +``` +What servers are my server talking to? +--- +Run this SQL query on your database: +```sql +SELECT * FROM destinations; +``` + +What servers are currently participating in this room?
+--- +Run this SQL query on your database: +```sql +SELECT DISTINCT split_part(state_key, ':', 2) + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + WHERE room_id = '!cURbafjkfsMDVwdRDQ:matrix.org' AND membership = 'join'; +``` + +What users are registered on my server? +--- +```sql +SELECT name FROM users; +``` + +Manually resetting passwords: +--- +See https://github.com/matrix-org/synapse/blob/master/README.rst#password-reset + +I have a problem with my server. Can I just delete my database and start again? +--- +Deleting your database is unlikely to make anything better. + +It's easy to make the mistake of thinking that you can start again from a clean slate by dropping your database, but things don't work like that in a federated network: lots of other servers have information about your server. + +For example: other servers might think that you are in a room, while your server thinks that you are not, and you'll probably be unable to interact with that room in a sensible way ever again. + +In general, there are better solutions to any problem than dropping the database. Come and seek help in https://matrix.to/#/#synapse:matrix.org. + +There are two exceptions when it might be sensible to delete your database and start again: +* You have *never* joined any rooms which are federated with other servers. For instance, a local deployment which the outside world can't talk to. +* You are changing the `server_name` in the homeserver configuration. In effect this makes your server a completely new one from the point of view of the network, so in this case it makes sense to start with a clean database. +(In both cases you probably also want to clear out the media_store.) + +I've stuffed up access to my room, how can I delete it to free up the alias? +--- +Use the following curl command: +``` +curl -H 'Authorization: Bearer <access-token>' -X DELETE https://matrix.org/_matrix/client/r0/directory/room/<room-alias> +``` +`<access-token>` - can be obtained in Riot by looking in the Riot settings; at the bottom is: +Access Token:\<click to reveal\> + +`<room-alias>` - the room alias, e.g. #my_room:matrix.org. This may also need to be URL encoded, for example %23my_room%3Amatrix.org + +How can I find the lines corresponding to a given HTTP request in my homeserver log? +--- + +Synapse tags each log line according to the HTTP request it is processing. When it finishes processing each request, it logs a line containing the words `Processed request: `. For example: + +``` +2019-02-14 22:35:08,196 - synapse.access.http.8008 - 302 - INFO - GET-37 - ::1 - 8008 - {@richvdh:localhost} Processed request: 0.173sec/0.001sec (0.002sec, 0.000sec) (0.027sec/0.026sec/2) 687B 200 "GET /_matrix/client/r0/sync HTTP/1.1" "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.100 Safari/537.36" [0 dbevts]" +``` + +Here we can see that the request has been tagged with `GET-37`. (The tag depends on the method of the HTTP request, so might start with `GET-`, `PUT-`, `POST-`, `OPTIONS-` or `DELETE-`.) So to find all lines corresponding to this request, we can do: + +``` +grep 'GET-37' homeserver.log +``` + +If you want to paste that output into a github issue or matrix room, please remember to surround it with triple-backticks (```) to make it legible (see https://help.github.com/en/articles/basic-writing-and-formatting-syntax#quoting-code). + + +What do all those fields in the 'Processed' line mean?
+
+
+What do all those fields in the 'Processed' line mean?
+---
+See [Request log format](request_log.md).
+
+
+What are the biggest rooms on my server?
+---
+
+```sql
+SELECT s.canonical_alias, g.room_id, count(*) AS num_rows
+FROM
+  state_groups_state AS g,
+  room_stats_state AS s
+WHERE g.room_id = s.room_id
+GROUP BY s.canonical_alias, g.room_id
+ORDER BY num_rows DESC
+LIMIT 10;
+```
+
+You can also use the [List Room API](../../admin_api/rooms.md#list-room-api)
+and `order_by` `state_events`.
diff --git a/docs/usage/administration/database_maintenance_tools.md b/docs/usage/administration/database_maintenance_tools.md
new file mode 100644
index 0000000000..92b805d413
--- /dev/null
+++ b/docs/usage/administration/database_maintenance_tools.md
@@ -0,0 +1,18 @@
+This blog post by Victor Berger explains how to use many of the tools listed on this page: https://levans.fr/shrink-synapse-database.html
+
+# List of useful tools and scripts for maintaining a Synapse database:
+
+## [Purge Remote Media API](../../admin_api/media_admin_api.md#purge-remote-media-api)
+The purge remote media API allows server admins to purge old cached remote media.
+
+## [Purge Local Media API](../../admin_api/media_admin_api.md#delete-local-media)
+This API deletes the *local* media from the disk of your own server.
+
+## [Purge History API](../../admin_api/purge_history_api.md)
+The purge history API allows server admins to purge historic events from their database, reclaiming disk space.
+
+## [synapse-compress-state](https://github.com/matrix-org/rust-synapse-compress-state)
+A tool for compressing (deduplicating) the `state_groups_state` table.
+
+## [SQL for analyzing Synapse PostgreSQL database stats](useful_sql_for_admins.md)
+Some easy SQL that reports useful stats about your Synapse database.
\ No newline at end of file
diff --git a/docs/usage/administration/state_groups.md b/docs/usage/administration/state_groups.md
new file mode 100644
index 0000000000..f1dee7accf
--- /dev/null
+++ b/docs/usage/administration/state_groups.md
@@ -0,0 +1,25 @@
+# How do State Groups work?
+
+As a general rule, I encourage people who want to understand the deepest darkest secrets of the database schema to drop by #synapse-dev:matrix.org and ask questions.
+
+However, one question that comes up frequently is that of how "state groups" work, and why the `state_groups_state` table gets so big, so here's an attempt to answer that question.
+
+We need to be able to relatively quickly calculate the state of a room at any point in that room's history. In other words, we need to know the state of the room at each event in that room. This is done as follows:
+
+A sequence of events where the state is the same are grouped together into a `state_group`; the mapping is recorded in `event_to_state_groups`. (Technically speaking, since a state event usually changes the state in the room, we are recording the state of the room *after* the given event id: which is to say, to a handwavey simplification, the first event in a state group is normally a state event, and others in the same state group are normally non-state-events.)
+
+`state_groups` records, for each state group, the id of the room that we're looking at, and also the id of the first event in that group. (I'm not sure if that event id is used much in practice.)
+
+Now, if we stored all the room state for each `state_group`, that would be a huge amount of data. Instead, for each state group, we normally store the difference between the state in that group and some other state group, and only occasionally (every 100 state changes or so) record the full state.
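+
+As a toy sketch of the resulting rows (the IDs and values here are
+hypothetical, and some columns are elided), a delta group might look like:
+
+```
+state_groups:        id=102, room_id='!room:example.com', event_id='$alice_join'
+state_group_edges:   state_group=102, prev_state_group=101
+state_groups_state:  state_group=102, type='m.room.member',
+                     state_key='@alice:example.com', event_id='$alice_join'
+```
+
+That is, group 102 only records what changed relative to group 101; resolving
+the full state at group 102 means walking backwards until a full snapshot is
+reached.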
+
+So, most state groups have an entry in `state_group_edges` (don't ask me why it's not a column in `state_groups`) which records the previous state group in the room, and `state_groups_state` records the differences in state since that previous state group.
+
+A full state group just records the event id for each piece of state in the room at that point.
+
+## Known bugs with state groups
+
+There are various reasons that we can end up creating many more state groups than we need: see https://github.com/matrix-org/synapse/issues/3364 for more details.
+
+## Compression tool
+
+There is a tool at https://github.com/matrix-org/rust-synapse-compress-state which can compress the `state_groups_state` table on a room-by-room basis (essentially, it reduces the number of "full" state groups). This can result in dramatic reductions in the storage used.
\ No newline at end of file
diff --git a/docs/usage/administration/understanding_synapse_through_grafana_graphs.md b/docs/usage/administration/understanding_synapse_through_grafana_graphs.md
new file mode 100644
index 0000000000..c365cc3923
--- /dev/null
+++ b/docs/usage/administration/understanding_synapse_through_grafana_graphs.md
@@ -0,0 +1,84 @@
+## Understanding Synapse through Grafana graphs
+
+It is possible to monitor much of the internal state of Synapse using [Prometheus](https://prometheus.io)
+metrics and [Grafana](https://grafana.com/).
+A guide for configuring Synapse to provide metrics is available [here](../../metrics-howto.md)
+and information on setting up Grafana is [here](https://github.com/matrix-org/synapse/tree/master/contrib/grafana).
+In this setup, Prometheus will periodically scrape the information Synapse provides and
+store a record of it over time. Grafana is then used as an interface to query and
+present this information through a series of pretty graphs.
+
+Once you have Grafana set up, and assuming you're using [our Grafana dashboard template](https://github.com/matrix-org/synapse/blob/master/contrib/grafana/synapse.json), look for the following graphs when debugging a slow/overloaded Synapse:
+
+## Message Event Send Time
+
+![image](https://user-images.githubusercontent.com/1342360/82239409-a1c8e900-9930-11ea-8081-e4614e0c63f4.png)
+
+This, along with the CPU and Memory graphs, is a good way to check the general health of your Synapse instance. It represents how long it takes for a user on your homeserver to send a message.
+
+## Transaction Count and Transaction Duration
+
+![image](https://user-images.githubusercontent.com/1342360/82239985-8d392080-9931-11ea-80d0-843ab2f22e1e.png)
+
+![image](https://user-images.githubusercontent.com/1342360/82240050-ab068580-9931-11ea-98f1-f94671cbac9a.png)
+
+These graphs show the database transactions that are occurring the most frequently, as well as those that are taking the most time to execute.
+
+![image](https://user-images.githubusercontent.com/1342360/82240192-e86b1300-9931-11ea-9aac-3e2c9bfa6fdc.png)
+
+In the first graph, we can see obvious spikes corresponding to lots of `get_user_by_id` transactions. This is useful information for figuring out which part of the Synapse codebase is potentially creating a heavy load on the system.
+However, be sure to cross-reference this with Transaction Duration, which shows that `get_user_by_id` is actually a very quick database transaction and isn't causing as much load as others, like `persist_events`:
+
+![image](https://user-images.githubusercontent.com/1342360/82240467-62030100-9932-11ea-8db9-917f2d977fe1.png)
+
+Still, it's probably worth investigating why we're getting users from the database that often, and whether it's possible to reduce the number of queries we make by adjusting our cache factor(s).
+
+The `persist_events` transaction is responsible for saving new room events to the Synapse database, so it can often show a high transaction duration.
+
+## Federation
+
+The charts in the "Federation" section show information about incoming and outgoing federation requests. Federation data can be divided into two basic types:
+
+- PDU (Persistent Data Unit) - room events: messages, state events (join/leave), etc. These are permanently stored in the database.
+- EDU (Ephemeral Data Unit) - other data, which need not be stored permanently, such as read receipts and typing notifications.
+
+The "Outgoing EDUs by type" chart shows the EDUs within outgoing federation requests by type: `m.device_list_update`, `m.direct_to_device`, `m.presence`, `m.receipt`, `m.typing`.
+
+If you see a large number of `m.presence` EDUs and are having trouble with too much CPU load, you can disable `presence` in the Synapse config. See also [#3971](https://github.com/matrix-org/synapse/issues/3971).
+
+## Caches
+
+![image](https://user-images.githubusercontent.com/1342360/82240572-8b239180-9932-11ea-96ff-6b5f0e57ebe5.png)
+
+![image](https://user-images.githubusercontent.com/1342360/82240666-b8703f80-9932-11ea-86af-9f663988d8da.png)
+
+This is quite a useful graph. It shows how many times Synapse attempts to retrieve a piece of data from a cache that does not contain it, thus resulting in a call to the database. We can see here that the `_get_joined_profile_from_event_id` cache is being requested a lot, and often the data we're after is not cached.
+
+Cross-referencing this with the Eviction Rate graph, which shows that entries are being evicted from `_get_joined_profile_from_event_id` quite often:
+
+![image](https://user-images.githubusercontent.com/1342360/82240766-de95df80-9932-11ea-8c15-5acfc57c48da.png)
+
+we should probably consider raising the size of that cache by raising its cache factor (a multiplier value for the size of an individual cache). Information on doing so is available [here](https://github.com/matrix-org/synapse/blob/ee421e524478c1ad8d43741c27379499c2f6135c/docs/sample_config.yaml#L608-L642) (note that the configuration of individual cache factors through the configuration file is available in Synapse v1.14.0+, whereas doing so through environment variables has been supported for a very long time). Note that this will increase Synapse's overall memory usage.
+
+## Forward Extremities
+
+![image](https://user-images.githubusercontent.com/1342360/82241440-13566680-9934-11ea-8b88-ba468db937ed.png)
+
+Forward extremities are the leaf events at the end of a DAG in a room, i.e. events that have no children. The more of them that exist in a room, the more [state resolution](https://spec.matrix.org/v1.1/server-server-api/#room-state-resolution) Synapse needs to perform (hint: it's an expensive operation). While Synapse has code to prevent too many of these existing at one time in a room, bugs can sometimes make them crop up again.
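+
+To see which rooms are accumulating extremities, a query along these lines can
+be run against the database (a sketch; it assumes the `event_forward_extremities`
+table present in current Synapse schemas):
+
+```sql
+SELECT room_id, count(*) AS extremities
+FROM event_forward_extremities
+GROUP BY room_id
+ORDER BY extremities DESC
+LIMIT 20;
+```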
+
+If a room has more than 10 forward extremities, it's worth checking which room is the culprit and potentially removing the extremities using the SQL queries mentioned in [#1760](https://github.com/matrix-org/synapse/issues/1760).
+
+## Garbage Collection
+
+![image](https://user-images.githubusercontent.com/1342360/82241911-da6ac180-9934-11ea-9a0d-a311fe22acd0.png)
+
+Large spikes in garbage collection times (bigger than shown here; think the
+multiple-seconds range) can cause lots of problems in Synapse performance. They are
+more an indicator and symptom of other problems, though, so check the other graphs
+for what might be causing them.
+
+## Final Thoughts
+
+If you're still having performance problems with your Synapse instance and you've
+tried everything you can, it may just be a lack of system resources. Consider adding
+more CPU and RAM, and use [worker mode](../../workers.md)
+to spread the load across multiple CPU cores or multiple machines.
+
diff --git a/docs/usage/administration/useful_sql_for_admins.md b/docs/usage/administration/useful_sql_for_admins.md
new file mode 100644
index 0000000000..d4aada3272
--- /dev/null
+++ b/docs/usage/administration/useful_sql_for_admins.md
@@ -0,0 +1,156 @@
+## Some useful SQL queries for Synapse Admins
+
+## Size of the full Matrix database
+`SELECT pg_size_pretty( pg_database_size( 'matrix' ) );`
+### Result example:
+```
+pg_size_pretty
+----------------
+ 6420 MB
+(1 row)
+```
+## Show top 20 largest rooms by state event count
+```sql
+SELECT r.name, s.room_id, s.current_state_events
+  FROM room_stats_current s
+  LEFT JOIN room_stats_state r USING (room_id)
+  ORDER BY current_state_events DESC
+  LIMIT 20;
+```
+
+and by `state_groups_state` row count:
+```sql
+SELECT rss.name, s.room_id, count(s.room_id) FROM state_groups_state s
+LEFT JOIN room_stats_state rss USING (room_id)
+GROUP BY s.room_id, rss.name
+ORDER BY count(s.room_id) DESC
+LIMIT 20;
+```
+The same again, but with the join removed for performance:
+```sql
+SELECT s.room_id, count(s.room_id) FROM state_groups_state s
+GROUP BY s.room_id
+ORDER BY count(s.room_id) DESC
+LIMIT 20;
+```
+
+## Show top 20 largest tables by row count
+```sql
+SELECT relname, n_live_tup as rows
+  FROM pg_stat_user_tables
+  ORDER BY n_live_tup DESC
+  LIMIT 20;
+```
+This query is quick, but may be very approximate; for an exact number of rows, use `SELECT COUNT(*) FROM <table_name>`.
+### Result example:
+```
+state_groups_state - 161687170
+event_auth - 8584785
+event_edges - 6995633
+event_json - 6585916
+event_reference_hashes - 6580990
+events - 6578879
+received_transactions - 5713989
+event_to_state_groups - 4873377
+stream_ordering_to_exterm - 4136285
+current_state_delta_stream - 3770972
+event_search - 3670521
+state_events - 2845082
+room_memberships - 2785854
+cache_invalidation_stream - 2448218
+state_groups - 1255467
+state_group_edges - 1229849
+current_state_events - 1222905
+users_in_public_rooms - 364059
+device_lists_stream - 326903
+user_directory_search - 316433
+```
+
+## Show top 20 rooms by new event count in the last day:
+```sql
+SELECT e.room_id, r.name, COUNT(e.event_id) cnt FROM events e
+LEFT JOIN room_stats_state r USING (room_id)
+WHERE e.origin_server_ts >= DATE_PART('epoch', NOW() - INTERVAL '1 day') * 1000 GROUP BY e.room_id, r.name ORDER BY cnt DESC LIMIT 20;
+```
+
+## Show top 20 users on the homeserver by sent events (messages) in the last month:
+```sql
+SELECT user_id, SUM(total_events)
+  FROM user_stats_historical
+  WHERE TO_TIMESTAMP(end_ts/1000) AT TIME ZONE 'UTC' > date_trunc('day', now() - interval '1 month')
+  GROUP BY user_id
+  ORDER BY SUM(total_events) DESC
+  LIMIT 20;
+```
+
+## Show the last 100 messages from a given user, with room names:
+```sql
+SELECT e.room_id, r.name, e.event_id, e.type, e.content, j.json FROM events e
+  LEFT JOIN event_json j USING (event_id)
+  LEFT JOIN room_stats_state r USING (room_id)
+  WHERE sender = '@LOGIN:example.com'
+  AND e.type = 'm.room.message'
+  ORDER BY stream_ordering DESC
+  LIMIT 100;
+```
+
+## Show top 20 largest tables by storage size
+```sql
+SELECT nspname || '.' || relname AS "relation",
+    pg_size_pretty(pg_total_relation_size(C.oid)) AS "total_size"
+  FROM pg_class C
+  LEFT JOIN pg_namespace N ON (N.oid = C.relnamespace)
+  WHERE nspname NOT IN ('pg_catalog', 'information_schema')
+    AND C.relkind <> 'i'
+    AND nspname !~ '^pg_toast'
+  ORDER BY pg_total_relation_size(C.oid) DESC
+  LIMIT 20;
+```
+### Result example:
+```
+public.state_groups_state - 27 GB
+public.event_json - 9855 MB
+public.events - 3675 MB
+public.event_edges - 3404 MB
+public.received_transactions - 2745 MB
+public.event_reference_hashes - 1864 MB
+public.event_auth - 1775 MB
+public.stream_ordering_to_exterm - 1663 MB
+public.event_search - 1370 MB
+public.room_memberships - 1050 MB
+public.event_to_state_groups - 948 MB
+public.current_state_delta_stream - 711 MB
+public.state_events - 611 MB
+public.presence_stream - 530 MB
+public.current_state_events - 525 MB
+public.cache_invalidation_stream - 466 MB
+public.receipts_linearized - 279 MB
+public.state_groups - 160 MB
+public.device_lists_remote_cache - 124 MB
+public.state_group_edges - 122 MB
+```
+
+## Show rooms with names, sorted by event count
+`echo "select event_json.room_id,room_stats_state.name from event_json,room_stats_state where room_stats_state.room_id=event_json.room_id" | psql synapse | sort | uniq -c | sort -n`
+### Result example:
+```
+   9459 !FPUfgzXYWTKgIrwKxW:matrix.org | This Week in Matrix
+   9459 !FPUfgzXYWTKgIrwKxW:matrix.org | This Week in Matrix (TWIM)
+  17799 !iDIOImbmXxwNngznsa:matrix.org | Linux in Russian
+  18739 !GnEEPYXUhoaHbkFBNX:matrix.org | Riot Android
+  23373 !QtykxKocfZaZOUrTwp:matrix.org | Matrix HQ
+  39504 !gTQfWzbYncrtNrvEkB:matrix.org | ru.[matrix]
+  43601 !iNmaIQExDMeqdITdHH:matrix.org | Riot
+  43601 !iNmaIQExDMeqdITdHH:matrix.org | Riot Web/Desktop
+```
+
+## Look up room state info for a list of room_ids
+```sql
+SELECT rss.room_id, rss.name, rss.canonical_alias, rss.topic, rss.encryption, rsc.joined_members, rsc.local_users_in_room, rss.join_rules
+FROM room_stats_state rss
+LEFT JOIN room_stats_current rsc USING (room_id)
+WHERE room_id IN (
+    '!OGEhHVWSdvArJzumhm:matrix.org',
+    '!YTvKGNlinIzlkMTVRl:matrix.org'
+);
+```
\ No newline at end of file
diff --git a/docs/usage/configuration/user_authentication/refresh_tokens.md b/docs/usage/configuration/user_authentication/refresh_tokens.md
new file mode 100644
index 0000000000..23b3cddae0
--- /dev/null
+++ b/docs/usage/configuration/user_authentication/refresh_tokens.md
@@ -0,0 +1,139 @@
+# Refresh Tokens
+
+Synapse supports refresh tokens since version 1.49 (some earlier versions had support for an earlier, experimental draft of [MSC2918] which is not compatible).
+
+
+[MSC2918]: https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens
+
+
+## Background and motivation
+
+Synapse users' sessions are identified by **access tokens**; access tokens are
+issued to users on login. Each session gets a unique access token which identifies
+it; the access token must be kept secret as it grants access to the user's account.
+
+Traditionally, these access tokens were eternally valid (at least until the user
+explicitly chose to log out).
+
+In some cases, it may be desirable for these access tokens to expire so that the
+potential damage caused by leaking an access token is reduced.
+On the other hand, forcing a user to re-authenticate (log in again) often might
+be too much of an inconvenience.
+
+**Refresh tokens** are a mechanism to avoid some of this inconvenience whilst
+still getting most of the benefits of short access token lifetimes.
+Refresh tokens are also a concept present in OAuth 2 — further reading is available
+[here](https://datatracker.ietf.org/doc/html/rfc6749#section-1.5).
+
+When refresh tokens are in use, both an access token and a refresh token will be
+issued to users on login. The access token will expire after a predetermined amount
+of time, but otherwise works in the same way as before. When the access token is
+close to expiring (or has expired), the user's client should present the homeserver
+(Synapse) with the refresh token.
+
+The homeserver will then generate a new access token and refresh token for the user
+and return them. The old refresh token is invalidated and cannot be used again*.
+
+Finally, refresh tokens also make it possible for sessions to be logged out if they
+are inactive for too long, before the session naturally ends; see the configuration
+guide below.
+
+
+*To prevent issues if clients lose connection half-way through refreshing a token,
+the refresh token is only invalidated once the new access token has been used at
+least once. For all intents and purposes, the above simplification is sufficient.
+
+
+## Caveats
+
+There are some caveats:
+
+* If a third party gets both your access token and refresh token, they will be able to
+  continue to enjoy access to your session.
+  * This is still an improvement because you (the user) will notice when *your*
+    session expires and you're not able to use your refresh token.
+    That would be a giveaway that someone else has compromised your session.
+    You would be able to log in again and terminate that session.
+    Previously (with long-lived access tokens), a third party that has your access
+    token could go undetected for a very long time.
+* Clients need to implement support for refresh tokens in order for them to be a
+  useful mechanism.
+  * It is up to homeserver administrators whether they want to issue long-lived access
+    tokens to clients that do not implement refresh tokens.
+    * For compatibility, it is likely that they should, at least until client support
+      is widespread.
+      * Users with clients that support refresh tokens will still benefit from the
+        added security; it's not possible to downgrade a session to using long-lived
+        access tokens, so this effectively gives users the choice.
+    * In a closed environment where all users use known clients, this may not be
+      an issue, as the homeserver administrator can know whether the clients have refresh
+      token support. In that case, the non-refreshable access token lifetime
+      may be set to a short duration so that a similar level of security is provided.
+
+
+## Configuration Guide
+
+The following configuration options, in the `registration` section, are related:
+
+* `session_lifetime`: maximum length of a session, even if it's refreshed.
+  In other words, the client must log in again after this time period.
+  In most cases, this can be unset (infinite) or set to a long time (years or months).
+* `refreshable_access_token_lifetime`: lifetime of access tokens that are created
+  by clients supporting refresh tokens.
+  This should be short; a good value might be 5 minutes (`5m`).
+* `nonrefreshable_access_token_lifetime`: lifetime of access tokens that are created
+  by clients which don't support refresh tokens.
+  Make this short if you want to effectively force use of refresh tokens.
+  Make this long if you don't want to inconvenience users of clients which don't
+  support refresh tokens (by forcing them to frequently re-authenticate using
+  login credentials).
+* `refresh_token_lifetime`: lifetime of refresh tokens.
+  In other words, the client must refresh within this time period to maintain its session.
+  Unless you want to log inactive sessions out, it is often fine to use a long
+  value here or even leave it unset (infinite).
+  Beware that making it too short will inconvenience clients that do not connect
+  very often, including mobile clients and clients of infrequent users (by making
+  it more difficult for them to refresh in time, which may force them to need to
+  re-authenticate using login credentials).
+
+**Note:** All four options above only apply when tokens are created (by logging in or refreshing).
+Changes to these settings do not apply retroactively.
+
+
+### Using refresh token expiry to log out inactive sessions
+
+If you'd like to force sessions to be logged out upon inactivity, you can enable
+refreshable access token expiry and refresh token expiry.
+
+This works because a client must refresh at least once within a period of
+`refresh_token_lifetime` in order to maintain valid credentials to access the
+account.
+
+(It's suggested that `refresh_token_lifetime` should be longer than
+`refreshable_access_token_lifetime` and this section assumes that to be the case
+for simplicity.)
+
+Note: this will only affect sessions using refresh tokens. You may wish to
+set a short `nonrefreshable_access_token_lifetime` to prevent this being bypassed
+by clients that do not support refresh tokens.
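+
+For example, to log sessions out after roughly a week of inactivity, a
+configuration along these lines could be used (a sketch — the values are
+illustrative, not recommendations):
+
+```yaml
+# In the registration settings of homeserver.yaml:
+session_lifetime: 1y                       # absolute cap on session length
+refreshable_access_token_lifetime: 5m      # short-lived access tokens
+refresh_token_lifetime: 7d                 # must refresh within a week
+nonrefreshable_access_token_lifetime: 1h   # discourage bypass by old clients
+```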
+
+
+#### Choosing values that guarantee permitting some inactivity
+
+It may be desirable to permit some short periods of inactivity, for example to
+accommodate brief outages in client connectivity.
+
+The following model aims to provide guidance for choosing `refresh_token_lifetime`
+and `refreshable_access_token_lifetime` to satisfy requirements of the form:
+
+1. inactivity longer than `L` **MUST** cause the session to be logged out; and
+2. inactivity shorter than `S` **MUST NOT** cause the session to be logged out.
+
+This model makes the weakest assumption that all active clients will refresh as
+needed to maintain an active access token, but no sooner.
+*In reality, clients may refresh more often than this model assumes, but the
+above requirements will still hold.*
+
+To satisfy the above model,
+* `refresh_token_lifetime` should be set to `L`; and
+* `refreshable_access_token_lifetime` should be set to `L - S`.
diff --git a/docs/usage/configuration/user_authentication/single_sign_on/README.md b/docs/usage/configuration/user_authentication/single_sign_on/README.md
new file mode 100644
index 0000000000..b94aad92cf
--- /dev/null
+++ b/docs/usage/configuration/user_authentication/single_sign_on/README.md
@@ -0,0 +1,5 @@
+# Single Sign-On
+
+Synapse supports single sign-on through the SAML, OpenID Connect or CAS protocols.
+LDAP and other login methods are supported through first- and third-party password
+auth provider modules.
\ No newline at end of file
diff --git a/docs/usage/configuration/user_authentication/single_sign_on/cas.md b/docs/usage/configuration/user_authentication/single_sign_on/cas.md
new file mode 100644
index 0000000000..3bac1b29f0
--- /dev/null
+++ b/docs/usage/configuration/user_authentication/single_sign_on/cas.md
@@ -0,0 +1,8 @@
+# CAS
+
+Synapse supports authenticating users via the [Central Authentication
+Service protocol](https://en.wikipedia.org/wiki/Central_Authentication_Service)
+(CAS) natively.
+
+Please see the `cas_config` and `sso` sections of the [Synapse configuration
+file](../../../configuration/homeserver_sample_config.md) for more details.
\ No newline at end of file
diff --git a/docs/usage/configuration/user_authentication/single_sign_on/saml.md b/docs/usage/configuration/user_authentication/single_sign_on/saml.md
new file mode 100644
index 0000000000..2b6f052cc1
--- /dev/null
+++ b/docs/usage/configuration/user_authentication/single_sign_on/saml.md
@@ -0,0 +1,8 @@
+# SAML
+
+Synapse supports authenticating users via the [Security Assertion
+Markup Language](https://en.wikipedia.org/wiki/Security_Assertion_Markup_Language)
+(SAML) protocol natively.
+
+Please see the `saml2_config` and `sso` sections of the [Synapse configuration
+file](../../../configuration/homeserver_sample_config.md) for more details.
\ No newline at end of file
diff --git a/docs/user_directory.md b/docs/user_directory.md
index 07fe954891..c4794b04cf 100644
--- a/docs/user_directory.md
+++ b/docs/user_directory.md
@@ -6,9 +6,9 @@ on this particular server - i.e. ones which your account shares a room with, or
 who are present in a publicly viewable room present on the server.
 
 The directory info is stored in various tables, which can (typically after
-DB corruption) get stale or out of sync. If this happens, for now the
-solution to fix it is to execute the SQL [here](https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/main/delta/53/user_dir_populate.sql)
-and then restart synapse. This should then start a background task to
+DB corruption) get stale or out of sync. If this happens, for now the
+solution to fix it is to use the [admin API](usage/administration/admin_api/background_updates.md#run)
+and execute the job `regenerate_directory`.
This should then start a background task to flush the current tables and regenerate the directory. Data model diff --git a/docs/workers.md b/docs/workers.md index f88e2c1de3..fd83e2ddeb 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -182,10 +182,10 @@ This worker can handle API requests matching the following regular expressions: # Sync requests - ^/_matrix/client/(v2_alpha|r0)/sync$ - ^/_matrix/client/(api/v1|v2_alpha|r0)/events$ - ^/_matrix/client/(api/v1|r0)/initialSync$ - ^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$ + ^/_matrix/client/(v2_alpha|r0|v3)/sync$ + ^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$ + ^/_matrix/client/(api/v1|r0|v3)/initialSync$ + ^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$ # Federation requests ^/_matrix/federation/v1/event/ @@ -210,46 +210,46 @@ expressions: ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query ^/_matrix/federation/unstable/org.matrix.msc2946/spaces/ - ^/_matrix/federation/unstable/org.matrix.msc2946/hierarchy/ + ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/ # Inbound federation transaction request ^/_matrix/federation/v1/send/ # Client API requests - ^/_matrix/client/(api/v1|r0|unstable)/createRoom$ - ^/_matrix/client/(api/v1|r0|unstable)/publicRooms$ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/joined_members$ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/context/.*$ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/createRoom$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/publicRooms$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/joined_members$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$ ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$ - ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/hierarchy$ + ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$ ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$ - ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$ - ^/_matrix/client/(api/v1|r0|unstable)/devices$ - ^/_matrix/client/(api/v1|r0|unstable)/keys/query$ - ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/query$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/changes$ ^/_matrix/client/versions$ - ^/_matrix/client/(api/v1|r0|unstable)/voip/turnServer$ - ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$ - ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$ - ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/event/ - ^/_matrix/client/(api/v1|r0|unstable)/joined_rooms$ - ^/_matrix/client/(api/v1|r0|unstable)/search$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_groups$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups/ + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ + ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ # Registration/login requests - ^/_matrix/client/(api/v1|r0|unstable)/login$ - ^/_matrix/client/(r0|unstable)/register$ + 
^/_matrix/client/(api/v1|r0|v3|unstable)/login$ + ^/_matrix/client/(r0|v3|unstable)/register$ ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$ # Event sending requests - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ - ^/_matrix/client/(api/v1|r0|unstable)/join/ - ^/_matrix/client/(api/v1|r0|unstable)/profile/ + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/redact + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/send + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state/ + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/join/ + ^/_matrix/client/(api/v1|r0|v3|unstable)/profile/ Additionally, the following REST endpoints can be handled for GET requests: @@ -261,14 +261,14 @@ room must be routed to the same instance. Additionally, care must be taken to ensure that the purge history admin API is not used while pagination requests for the room are in flight: - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/messages$ Additionally, the following endpoints should be included if Synapse is configured to use SSO (you only need to include the ones for whichever SSO provider you're using): # for all SSO providers - ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect + ^/_matrix/client/(api/v1|r0|v3|unstable)/login/sso/redirect ^/_synapse/client/pick_idp$ ^/_synapse/client/pick_username ^/_synapse/client/new_user_consent$ @@ -281,7 +281,7 @@ using): ^/_synapse/client/saml2/authn_response$ # CAS requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/login/cas/ticket$ Ensure that all SSO logins go to a single process. For multiple workers not handling the SSO endpoints properly, see @@ -465,7 +465,7 @@ Note that if a reverse proxy is used , then `/_matrix/media/` must be routed for Handles searches in the user directory. It can handle REST endpoints matching the following regular expressions: - ^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$ When using this worker you must also set `update_user_directory: False` in the shared configuration file to stop the main synapse running background @@ -477,12 +477,12 @@ Proxies some frequently-requested client endpoints to add caching and remove load from the main synapse. It can handle REST endpoints matching the following regular expressions: - ^/_matrix/client/(api/v1|r0|unstable)/keys/upload + ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload If `use_presence` is False in the homeserver config, it can also handle REST endpoints matching the following regular expressions: - ^/_matrix/client/(api/v1|r0|unstable)/presence/[^/]+/status + ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/[^/]+/status This "stub" presence handler will pass through `GET` request but make the `PUT` effectively a no-op. 
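As a point of reference, routing one of these endpoint patterns to a dedicated
worker usually happens in the reverse proxy. A minimal nginx sketch (the port
numbers and worker layout here are hypothetical):

```
# Send end-to-end key uploads to a frontend proxy worker on port 8083;
# everything else stays with the main Synapse process on port 8008.
location ~ ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload {
    proxy_pass http://localhost:8083;
}
location /_matrix {
    proxy_pass http://localhost:8008;
}
```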
diff --git a/mypy.ini b/mypy.ini index 1752b82bc5..1caf807e85 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,64 +23,30 @@ files = # https://docs.python.org/3/library/re.html#re.X exclude = (?x) ^( - |synapse/_scripts/register_new_matrix_user.py - |synapse/_scripts/review_recent_signups.py - |synapse/app/__init__.py - |synapse/app/_base.py - |synapse/app/admin_cmd.py - |synapse/app/appservice.py - |synapse/app/client_reader.py - |synapse/app/event_creator.py - |synapse/app/federation_reader.py - |synapse/app/federation_sender.py - |synapse/app/frontend_proxy.py - |synapse/app/generic_worker.py - |synapse/app/homeserver.py - |synapse/app/media_repository.py - |synapse/app/phone_stats_home.py - |synapse/app/pusher.py - |synapse/app/synchrotron.py - |synapse/app/user_dir.py |synapse/storage/databases/__init__.py |synapse/storage/databases/main/__init__.py |synapse/storage/databases/main/account_data.py |synapse/storage/databases/main/cache.py - |synapse/storage/databases/main/censor_events.py - |synapse/storage/databases/main/deviceinbox.py |synapse/storage/databases/main/devices.py - |synapse/storage/databases/main/directory.py |synapse/storage/databases/main/e2e_room_keys.py |synapse/storage/databases/main/end_to_end_keys.py |synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_push_actions.py |synapse/storage/databases/main/events_bg_updates.py - |synapse/storage/databases/main/events_forward_extremities.py - |synapse/storage/databases/main/events_worker.py - |synapse/storage/databases/main/filtering.py |synapse/storage/databases/main/group_server.py - |synapse/storage/databases/main/lock.py - |synapse/storage/databases/main/media_repository.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py - |synapse/storage/databases/main/openid.py |synapse/storage/databases/main/presence.py - |synapse/storage/databases/main/profile.py |synapse/storage/databases/main/purge_events.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py - |synapse/storage/databases/main/rejections.py |synapse/storage/databases/main/room.py - |synapse/storage/databases/main/room_batch.py |synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/search.py - |synapse/storage/databases/main/signatures.py |synapse/storage/databases/main/state.py - |synapse/storage/databases/main/state_deltas.py |synapse/storage/databases/main/stats.py - |synapse/storage/databases/main/tags.py |synapse/storage/databases/main/transactions.py |synapse/storage/databases/main/user_directory.py - |synapse/storage/databases/main/user_erasure_store.py |synapse/storage/schema/ |tests/api/test_auth.py @@ -120,9 +86,6 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/admin/test_admin.py - |tests/rest/admin/test_device.py - |tests/rest/admin/test_media.py - |tests/rest/admin/test_server_notice.py |tests/rest/admin/test_user.py |tests/rest/admin/test_username_available.py |tests/rest/client/test_account.py @@ -145,7 +108,6 @@ exclude = (?x) |tests/server_notices/test_resource_limits_server_notices.py |tests/state/test_v2.py |tests/storage/test_account_data.py - |tests/storage/test_appservice.py |tests/storage/test_background_update.py |tests/storage/test_base.py |tests/storage/test_client_ips.py @@ -158,7 +120,6 @@ exclude = (?x) |tests/test_server.py |tests/test_state.py |tests/test_terms_auth.py - |tests/test_visibility.py |tests/unittest.py 
|tests/util/caches/test_cached_call.py |tests/util/caches/test_deferred_cache.py @@ -181,128 +142,93 @@ exclude = (?x) [mypy-synapse.api.*] disallow_untyped_defs = True -[mypy-synapse.crypto.*] +[mypy-synapse.app.*] disallow_untyped_defs = True -[mypy-synapse.events.*] +[mypy-synapse.config._base] disallow_untyped_defs = True -[mypy-synapse.handlers.*] -disallow_untyped_defs = True - -[mypy-synapse.push.*] -disallow_untyped_defs = True - -[mypy-synapse.rest.*] -disallow_untyped_defs = True - -[mypy-synapse.server_notices.*] -disallow_untyped_defs = True - -[mypy-synapse.state.*] -disallow_untyped_defs = True - -[mypy-synapse.storage.databases.main.client_ips] -disallow_untyped_defs = True - -[mypy-synapse.storage.util.*] -disallow_untyped_defs = True - -[mypy-synapse.streams.*] -disallow_untyped_defs = True - -[mypy-synapse.util.batching_queue] -disallow_untyped_defs = True - -[mypy-synapse.util.caches.cached_call] -disallow_untyped_defs = True - -[mypy-synapse.util.caches.dictionary_cache] -disallow_untyped_defs = True - -[mypy-synapse.util.caches.lrucache] +[mypy-synapse.crypto.*] disallow_untyped_defs = True -[mypy-synapse.util.caches.response_cache] +[mypy-synapse.events.*] disallow_untyped_defs = True -[mypy-synapse.util.caches.stream_change_cache] +[mypy-synapse.federation.*] disallow_untyped_defs = True -[mypy-synapse.util.caches.ttl_cache] -disallow_untyped_defs = True +[mypy-synapse.federation.transport.client] +disallow_untyped_defs = False -[mypy-synapse.util.daemonize] +[mypy-synapse.handlers.*] disallow_untyped_defs = True -[mypy-synapse.util.file_consumer] +[mypy-synapse.metrics.*] disallow_untyped_defs = True -[mypy-synapse.util.frozenutils] +[mypy-synapse.module_api.*] disallow_untyped_defs = True -[mypy-synapse.util.hash] +[mypy-synapse.push.*] disallow_untyped_defs = True -[mypy-synapse.util.httpresourcetree] +[mypy-synapse.rest.*] disallow_untyped_defs = True -[mypy-synapse.util.iterutils] +[mypy-synapse.server_notices.*] disallow_untyped_defs = True -[mypy-synapse.util.linked_list] +[mypy-synapse.state.*] disallow_untyped_defs = True -[mypy-synapse.util.logcontext] +[mypy-synapse.storage.databases.main.client_ips] disallow_untyped_defs = True -[mypy-synapse.util.logformatter] +[mypy-synapse.storage.databases.main.directory] disallow_untyped_defs = True -[mypy-synapse.util.macaroons] +[mypy-synapse.storage.databases.main.events_worker] disallow_untyped_defs = True -[mypy-synapse.util.manhole] +[mypy-synapse.storage.databases.main.room_batch] disallow_untyped_defs = True -[mypy-synapse.util.module_loader] +[mypy-synapse.storage.databases.main.profile] disallow_untyped_defs = True -[mypy-synapse.util.msisdn] +[mypy-synapse.storage.databases.main.state_deltas] disallow_untyped_defs = True -[mypy-synapse.util.patch_inline_callbacks] +[mypy-synapse.storage.databases.main.user_erasure_store] disallow_untyped_defs = True -[mypy-synapse.util.ratelimitutils] +[mypy-synapse.storage.util.*] disallow_untyped_defs = True -[mypy-synapse.util.retryutils] +[mypy-synapse.streams.*] disallow_untyped_defs = True -[mypy-synapse.util.rlimit] +[mypy-synapse.util.*] disallow_untyped_defs = True -[mypy-synapse.util.stringutils] -disallow_untyped_defs = True +[mypy-synapse.util.caches.treecache] +disallow_untyped_defs = False -[mypy-synapse.util.templates] +[mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True -[mypy-synapse.util.threepids] +[mypy-tests.storage.test_profile] disallow_untyped_defs = True -[mypy-synapse.util.wheel_timer] +[mypy-tests.storage.test_user_directory] 
disallow_untyped_defs = True -[mypy-synapse.util.versionstring] +[mypy-tests.rest.client.test_directory] disallow_untyped_defs = True -[mypy-tests.handlers.test_user_directory] +[mypy-tests.federation.transport.test_client] disallow_untyped_defs = True -[mypy-tests.storage.test_user_directory] -disallow_untyped_defs = True ;; Dependencies without annotations ;; Before ignoring a module, check to see if type stubs are available. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 927c69a753..16be288ac1 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -24,7 +24,7 @@ set -e # Change to the repository root -cd "$(dirname "$0")/.." +cd "$(dirname $0)/.." # Check for a user-specified Complement checkout if [[ -z "$COMPLEMENT_DIR" ]]; then @@ -61,8 +61,8 @@ cd "$COMPLEMENT_DIR" EXTRA_COMPLEMENT_ARGS="" if [[ -n "$1" ]]; then # A test name regex has been set, supply it to Complement - EXTRA_COMPLEMENT_ARGS=(-run "$1") + EXTRA_COMPLEMENT_ARGS+="-run $1 " fi # Run the tests! -go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403,msc2716 -count=1 "${EXTRA_COMPLEMENT_ARGS[@]}" ./tests +go test -v -tags synapse_blacklist,msc2403,msc2716 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 6f76c08fcf..c72e19f61d 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -15,6 +15,25 @@ # See the License for the specific language governing permissions and # limitations under the License. + +""" +Script for signing and sending federation requests. + +Some tips on doing the join dance with this: + + room_id=... + user_id=... + + # make_join + federation_client.py "/_matrix/federation/v1/make_join/$room_id/$user_id?ver=5" > make_join.json + + # sign + jq -M .event make_join.json | sign_json --sign-event-room-version=$(jq -r .room_version make_join.json) -o signed-join.json + + # send_join + federation_client.py -X PUT "/_matrix/federation/v2/send_join/$room_id/x" --body $(<signed-join.json) > send_join.json +""" + import argparse import base64 import json diff --git a/scripts-dev/sign_json b/scripts-dev/sign_json index 6ac55ef2f7..9459543106 100755 --- a/scripts-dev/sign_json +++ b/scripts-dev/sign_json @@ -22,6 +22,8 @@ import yaml from signedjson.key import read_signing_keys from signedjson.sign import sign_json +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.util import json_encoder @@ -68,6 +70,16 @@ Example usage: ), ) + parser.add_argument( + "--sign-event-room-version", + type=str, + help=( + "Sign the JSON as an event for the given room version, rather than raw JSON. " + "This means that we will add a 'hashes' object, and redact the event before " + "signing." 
+ ), + ) + input_args = parser.add_mutually_exclusive_group() input_args.add_argument("input_data", nargs="?", help="Raw JSON to be signed.") @@ -116,7 +128,17 @@ Example usage: print("Input json was not an object", file=sys.stderr) sys.exit(1) - sign_json(obj, args.server_name, keys[0]) + if args.sign_event_room_version: + room_version = KNOWN_ROOM_VERSIONS.get(args.sign_event_room_version) + if not room_version: + print( + f"Unknown room version {args.sign_event_room_version}", file=sys.stderr + ) + sys.exit(1) + add_hashes_and_signatures(room_version, obj, args.server_name, keys[0]) + else: + sign_json(obj, args.server_name, keys[0]) + for c in json_encoder.iterencode(obj): args.output.write(c) args.output.write("\n") diff --git a/setup.py b/setup.py index 5d602db240..2c6fb9aacb 100755 --- a/setup.py +++ b/setup.py @@ -110,6 +110,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [ "types-Pillow>=8.3.4", "types-pyOpenSSL>=20.0.7", "types-PyYAML>=5.4.10", + "types-requests>=2.26.0", "types-setuptools>=57.4.0", ] @@ -118,7 +119,9 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [ # Tests assume that all optional dependencies are installed. # # parameterized_class decorator was introduced in parameterized 0.7.0 -CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"] +# +# We use `mock` library as that backports `AsyncMock` to Python 3.6 +CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"] CONDITIONAL_REQUIREMENTS["dev"] = ( CONDITIONAL_REQUIREMENTS["lint"] @@ -149,6 +152,12 @@ setup( long_description=long_description, long_description_content_type="text/x-rst", python_requires="~=3.6", + entry_points={ + "console_scripts": [ + "synapse_homeserver = synapse.app.homeserver:main", + "synapse_worker = synapse.app.generic_worker:main", + ] + }, classifiers=[ "Development Status :: 5 - Production/Stable", "Topic :: Communications :: Chat", diff --git a/synapse/__init__.py b/synapse/__init__.py index 5ef34bce40..6369f18a53 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.46.0" +__version__ = "1.49.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index dae986c788..4ffe6a1ef3 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -1,5 +1,6 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,22 +20,23 @@ import hashlib import hmac import logging import sys +from typing import Callable, Optional import requests as _requests import yaml def request_registration( - user, - password, - server_location, - shared_secret, - admin=False, - user_type=None, + user: str, + password: str, + server_location: str, + shared_secret: str, + admin: bool = False, + user_type: Optional[str] = None, requests=_requests, - _print=print, - exit=sys.exit, -): + _print: Callable[[str], None] = print, + exit: Callable[[int], None] = sys.exit, +) -> None: url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) @@ -65,13 +67,13 @@ def request_registration( mac.update(b"\x00") mac.update(user_type.encode("utf8")) - mac = mac.hexdigest() + hex_mac = mac.hexdigest() data = { "nonce": nonce, "username": user, "password": password, - "mac": mac, + "mac": hex_mac, "admin": admin, "user_type": user_type, } @@ -91,10 +93,17 @@ def request_registration( _print("Success!") -def register_new_user(user, password, server_location, shared_secret, admin, user_type): +def register_new_user( + user: str, + password: str, + server_location: str, + shared_secret: str, + admin: Optional[bool], + user_type: Optional[str], +) -> None: if not user: try: - default_user = getpass.getuser() + default_user: Optional[str] = getpass.getuser() except Exception: default_user = None @@ -123,8 +132,8 @@ def register_new_user(user, password, server_location, shared_secret, admin, use sys.exit(1) if admin is None: - admin = input("Make admin [no]: ") - if admin in ("y", "yes", "true"): + admin_inp = input("Make admin [no]: ") + if admin_inp in ("y", "yes", "true"): admin = True else: admin = False @@ -134,7 +143,7 @@ def register_new_user(user, password, server_location, shared_secret, admin, use ) -def main(): +def main() -> None: logging.captureWarnings(True) diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py index 8e66a38421..093af4327a 100644 --- a/synapse/_scripts/review_recent_signups.py +++ b/synapse/_scripts/review_recent_signups.py @@ -92,7 +92,7 @@ def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]: return user_infos -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( "-c", @@ -142,7 +142,8 @@ def main(): engine = create_engine(database_config.config) with make_conn(database_config, engine, "review_recent_signups") as db_conn: - user_infos = get_recent_users(db_conn.cursor(), since_ms) + # This generates a type of Cursor, not LoggingTransaction. 
+ user_infos = get_recent_users(db_conn.cursor(), since_ms) # type: ignore[arg-type] for user_info in user_infos: if exclude_users_with_email and user_info.emails: diff --git a/synapse/api/constants.py b/synapse/api/constants.py index a33ac34161..52c083a20b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,6 +17,8 @@ """Contains constants from the specification.""" +from typing_extensions import Final + # the max size of a (canonical-json-encoded) event MAX_PDU_SIZE = 65536 @@ -39,125 +41,125 @@ class Membership: """Represents the membership states of a user in a room.""" - INVITE = "invite" - JOIN = "join" - KNOCK = "knock" - LEAVE = "leave" - BAN = "ban" - LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) + INVITE: Final = "invite" + JOIN: Final = "join" + KNOCK: Final = "knock" + LEAVE: Final = "leave" + BAN: Final = "ban" + LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN) class PresenceState: """Represents the presence state of a user.""" - OFFLINE = "offline" - UNAVAILABLE = "unavailable" - ONLINE = "online" - BUSY = "org.matrix.msc3026.busy" + OFFLINE: Final = "offline" + UNAVAILABLE: Final = "unavailable" + ONLINE: Final = "online" + BUSY: Final = "org.matrix.msc3026.busy" class JoinRules: - PUBLIC = "public" - KNOCK = "knock" - INVITE = "invite" - PRIVATE = "private" + PUBLIC: Final = "public" + KNOCK: Final = "knock" + INVITE: Final = "invite" + PRIVATE: Final = "private" # As defined for MSC3083. - RESTRICTED = "restricted" + RESTRICTED: Final = "restricted" class RestrictedJoinRuleTypes: """Understood types for the allow rules in restricted join rules.""" - ROOM_MEMBERSHIP = "m.room_membership" + ROOM_MEMBERSHIP: Final = "m.room_membership" class LoginType: - PASSWORD = "m.login.password" - EMAIL_IDENTITY = "m.login.email.identity" - MSISDN = "m.login.msisdn" - RECAPTCHA = "m.login.recaptcha" - TERMS = "m.login.terms" - SSO = "m.login.sso" - DUMMY = "m.login.dummy" - REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" + PASSWORD: Final = "m.login.password" + EMAIL_IDENTITY: Final = "m.login.email.identity" + MSISDN: Final = "m.login.msisdn" + RECAPTCHA: Final = "m.login.recaptcha" + TERMS: Final = "m.login.terms" + SSO: Final = "m.login.sso" + DUMMY: Final = "m.login.dummy" + REGISTRATION_TOKEN: Final = "org.matrix.msc3231.login.registration_token" # This is used in the `type` parameter for /register when called by # an appservice to register a new user. 
-APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service" +APP_SERVICE_REGISTRATION_TYPE: Final = "m.login.application_service" class EventTypes: - Member = "m.room.member" - Create = "m.room.create" - Tombstone = "m.room.tombstone" - JoinRules = "m.room.join_rules" - PowerLevels = "m.room.power_levels" - Aliases = "m.room.aliases" - Redaction = "m.room.redaction" - ThirdPartyInvite = "m.room.third_party_invite" - RelatedGroups = "m.room.related_groups" - - RoomHistoryVisibility = "m.room.history_visibility" - CanonicalAlias = "m.room.canonical_alias" - Encrypted = "m.room.encrypted" - RoomAvatar = "m.room.avatar" - RoomEncryption = "m.room.encryption" - GuestAccess = "m.room.guest_access" + Member: Final = "m.room.member" + Create: Final = "m.room.create" + Tombstone: Final = "m.room.tombstone" + JoinRules: Final = "m.room.join_rules" + PowerLevels: Final = "m.room.power_levels" + Aliases: Final = "m.room.aliases" + Redaction: Final = "m.room.redaction" + ThirdPartyInvite: Final = "m.room.third_party_invite" + RelatedGroups: Final = "m.room.related_groups" + + RoomHistoryVisibility: Final = "m.room.history_visibility" + CanonicalAlias: Final = "m.room.canonical_alias" + Encrypted: Final = "m.room.encrypted" + RoomAvatar: Final = "m.room.avatar" + RoomEncryption: Final = "m.room.encryption" + GuestAccess: Final = "m.room.guest_access" # These are used for validation - Message = "m.room.message" - Topic = "m.room.topic" - Name = "m.room.name" + Message: Final = "m.room.message" + Topic: Final = "m.room.topic" + Name: Final = "m.room.name" - ServerACL = "m.room.server_acl" - Pinned = "m.room.pinned_events" + ServerACL: Final = "m.room.server_acl" + Pinned: Final = "m.room.pinned_events" - Retention = "m.room.retention" + Retention: Final = "m.room.retention" - Dummy = "org.matrix.dummy_event" + Dummy: Final = "org.matrix.dummy_event" - SpaceChild = "m.space.child" - SpaceParent = "m.space.parent" + SpaceChild: Final = "m.space.child" + SpaceParent: Final = "m.space.parent" - MSC2716_INSERTION = "org.matrix.msc2716.insertion" - MSC2716_BATCH = "org.matrix.msc2716.batch" - MSC2716_MARKER = "org.matrix.msc2716.marker" + MSC2716_INSERTION: Final = "org.matrix.msc2716.insertion" + MSC2716_BATCH: Final = "org.matrix.msc2716.batch" + MSC2716_MARKER: Final = "org.matrix.msc2716.marker" class ToDeviceEventTypes: - RoomKeyRequest = "m.room_key_request" + RoomKeyRequest: Final = "m.room_key_request" class DeviceKeyAlgorithms: """Spec'd algorithms for the generation of per-device keys""" - ED25519 = "ed25519" - CURVE25519 = "curve25519" - SIGNED_CURVE25519 = "signed_curve25519" + ED25519: Final = "ed25519" + CURVE25519: Final = "curve25519" + SIGNED_CURVE25519: Final = "signed_curve25519" class EduTypes: - Presence = "m.presence" + Presence: Final = "m.presence" class RejectedReason: - AUTH_ERROR = "auth_error" + AUTH_ERROR: Final = "auth_error" class RoomCreationPreset: - PRIVATE_CHAT = "private_chat" - PUBLIC_CHAT = "public_chat" - TRUSTED_PRIVATE_CHAT = "trusted_private_chat" + PRIVATE_CHAT: Final = "private_chat" + PUBLIC_CHAT: Final = "public_chat" + TRUSTED_PRIVATE_CHAT: Final = "trusted_private_chat" class ThirdPartyEntityKind: - USER = "user" - LOCATION = "location" + USER: Final = "user" + LOCATION: Final = "location" -ServerNoticeMsgType = "m.server_notice" -ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" +ServerNoticeMsgType: Final = "m.server_notice" +ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached" class UserTypes: @@ -165,91 +167,95 @@ class 
UserTypes: 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ - SUPPORT = "support" - BOT = "bot" - ALL_USER_TYPES = (SUPPORT, BOT) + SUPPORT: Final = "support" + BOT: Final = "bot" + ALL_USER_TYPES: Final = (SUPPORT, BOT) class RelationTypes: """The types of relations known to this server.""" - ANNOTATION = "m.annotation" - REPLACE = "m.replace" - REFERENCE = "m.reference" - THREAD = "io.element.thread" + ANNOTATION: Final = "m.annotation" + REPLACE: Final = "m.replace" + REFERENCE: Final = "m.reference" + THREAD: Final = "io.element.thread" class LimitBlockingTypes: """Reasons that a server may be blocked""" - MONTHLY_ACTIVE_USER = "monthly_active_user" - HS_DISABLED = "hs_disabled" + MONTHLY_ACTIVE_USER: Final = "monthly_active_user" + HS_DISABLED: Final = "hs_disabled" class EventContentFields: """Fields found in events' content, regardless of type.""" # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 - LABELS = "org.matrix.labels" + LABELS: Final = "org.matrix.labels" # Timestamp to delete the event after # cf https://github.com/matrix-org/matrix-doc/pull/2228 - SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" + SELF_DESTRUCT_AFTER: Final = "org.matrix.self_destruct_after" # cf https://github.com/matrix-org/matrix-doc/pull/1772 - ROOM_TYPE = "type" + ROOM_TYPE: Final = "type" # Whether a room can federate. - FEDERATE = "m.federate" + FEDERATE: Final = "m.federate" # The creator of the room, as used in `m.room.create` events. - ROOM_CREATOR = "creator" + ROOM_CREATOR: Final = "creator" # Used in m.room.guest_access events. - GUEST_ACCESS = "guest_access" + GUEST_ACCESS: Final = "guest_access" # Used on normal messages to indicate they were historically imported after the fact - MSC2716_HISTORICAL = "org.matrix.msc2716.historical" + MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical" # For "insertion" events to indicate what the next batch ID should be in # order to connect to it - MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id" + MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id" # Used on "batch" events to indicate which insertion event it connects to - MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id" + MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id" # For "marker" events - MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion" + MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion" # The authorising user for joining a restricted room. 
- AUTHORISING_USER = "join_authorised_via_users_server" + AUTHORISING_USER: Final = "join_authorised_via_users_server" class RoomTypes: """Understood values of the room_type field of m.room.create events.""" - SPACE = "m.space" + SPACE: Final = "m.space" class RoomEncryptionAlgorithms: - MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" - DEFAULT = MEGOLM_V1_AES_SHA2 + MEGOLM_V1_AES_SHA2: Final = "m.megolm.v1.aes-sha2" + DEFAULT: Final = MEGOLM_V1_AES_SHA2 class AccountDataTypes: - DIRECT = "m.direct" - IGNORED_USER_LIST = "m.ignored_user_list" + DIRECT: Final = "m.direct" + IGNORED_USER_LIST: Final = "m.ignored_user_list" class HistoryVisibility: - INVITED = "invited" - JOINED = "joined" - SHARED = "shared" - WORLD_READABLE = "world_readable" + INVITED: Final = "invited" + JOINED: Final = "joined" + SHARED: Final = "shared" + WORLD_READABLE: Final = "world_readable" class GuestAccess: - CAN_JOIN = "can_join" + CAN_JOIN: Final = "can_join" # anything that is not "can_join" is considered "forbidden", but for completeness: - FORBIDDEN = "forbidden" + FORBIDDEN: Final = "forbidden" + + +class ReceiptTypes: + READ: Final = "m.read" class ReadReceiptEventFields: - MSC2285_HIDDEN = "org.matrix.msc2285.hidden" + MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 4486b3bc7d..f9f9467dc1 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -30,7 +30,8 @@ FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" -MEDIA_PREFIX = "/_matrix/media/r0" +MEDIA_R0_PREFIX = "/_matrix/media/r0" +MEDIA_V3_PREFIX = "/_matrix/media/v3" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index f9940491e8..ee51480a9e 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import sys +from typing import Container from synapse import python_dependencies # noqa: E402 @@ -27,7 +28,9 @@ except python_dependencies.DependencyException as e: sys.exit(1) -def check_bind_error(e, address, bind_addresses): +def check_bind_error( + e: Exception, address: str, bind_addresses: Container[str] +) -> None: """ This method checks an exception occurred while binding on 0.0.0.0. If :: is specified in the bind addresses a warning is shown. @@ -38,9 +41,9 @@ def check_bind_error(e, address, bind_addresses): When binding on 0.0.0.0 after :: this can safely be ignored. Args: - e (Exception): Exception that was caught. - address (str): Address on which binding was attempted. - bind_addresses (list): Addresses on which the service listens. + e: Exception that was caught. + address: Address on which binding was attempted. + bind_addresses: Addresses on which the service listens. 
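For context on `check_bind_error` above: on a dual-stack host, listening on "::" usually also accepts IPv4 connections, so a subsequent bind on "0.0.0.0" can fail with "address in use" even though nothing is misconfigured; the helper downgrades exactly that case to a warning and re-raises anything else. A rough sketch of the calling pattern, with a bare ServerFactory purely as a placeholder:

    from twisted.internet import error, reactor
    from twisted.internet.protocol import ServerFactory

    from synapse.app import check_bind_error

    bind_addresses = ["::", "0.0.0.0"]
    for address in bind_addresses:
        try:
            reactor.listenTCP(8448, ServerFactory(), interface=address)
        except error.CannotListenError as e:
            # Warns (instead of failing) when "::" already covers 0.0.0.0;
            # re-raises for any other bind failure.
            check_bind_error(e, address, bind_addresses)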
""" if address == "0.0.0.0" and "::" in bind_addresses: logger.warning( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index f2c1028b5d..5fc59c1be1 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -22,13 +22,28 @@ import socket import sys import traceback import warnings -from typing import TYPE_CHECKING, Awaitable, Callable, Iterable +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Collection, + Dict, + Iterable, + List, + NoReturn, + Optional, + Tuple, + cast, +) from cryptography.utils import CryptographyDeprecationWarning -from typing_extensions import NoReturn import twisted -from twisted.internet import defer, error, reactor +from twisted.internet import defer, error, reactor as _reactor +from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorSSL, IReactorTCP +from twisted.internet.protocol import ServerFactory +from twisted.internet.tcp import Port from twisted.logger import LoggingFile, LogLevel from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.python.threadpool import ThreadPool @@ -48,6 +63,7 @@ from synapse.logging.context import PreserveLoggingContext from synapse.metrics import register_threadpool from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats +from synapse.types import ISynapseReactor from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process from synapse.util.gai_resolver import GAIResolver @@ -57,33 +73,44 @@ from synapse.util.versionstring import get_version_string if TYPE_CHECKING: from synapse.server import HomeServer +# Twisted injects the global reactor to make it easier to import, this confuses +# mypy which thinks it is a module. Tell it that it a more proper type. +reactor = cast(ISynapseReactor, _reactor) + + logger = logging.getLogger(__name__) # list of tuples of function, args list, kwargs dict -_sighup_callbacks = [] +_sighup_callbacks: List[ + Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]] +] = [] -def register_sighup(func, *args, **kwargs): +def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None: """ Register a function to be called when a SIGHUP occurs. Args: - func (function): Function to be called when sent a SIGHUP signal. + func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ _sighup_callbacks.append((func, args, kwargs)) -def start_worker_reactor(appname, config, run_command=reactor.run): +def start_worker_reactor( + appname: str, + config: HomeServerConfig, + run_command: Callable[[], None] = reactor.run, +) -> None: """Run the reactor in the main process Daemonizes if necessary, and then configures some resources, before starting the reactor. Pulls configuration from the 'worker' settings in 'config'. 
Args: - appname (str): application name which will be sent to syslog - config (synapse.config.Config): config object - run_command (Callable[]): callable that actually runs the reactor + appname: application name which will be sent to syslog + config: config object + run_command: callable that actually runs the reactor """ logger = logging.getLogger(config.worker.worker_app) @@ -101,32 +128,32 @@ def start_worker_reactor(appname, config, run_command=reactor.run): def start_reactor( - appname, - soft_file_limit, - gc_thresholds, - pid_file, - daemonize, - print_pidfile, - logger, - run_command=reactor.run, -): + appname: str, + soft_file_limit: int, + gc_thresholds: Optional[Tuple[int, int, int]], + pid_file: str, + daemonize: bool, + print_pidfile: bool, + logger: logging.Logger, + run_command: Callable[[], None] = reactor.run, +) -> None: """Run the reactor in the main process Daemonizes if necessary, and then configures some resources, before starting the reactor Args: - appname (str): application name which will be sent to syslog - soft_file_limit (int): + appname: application name which will be sent to syslog + soft_file_limit: gc_thresholds: - pid_file (str): name of pid file to write to if daemonize is True - daemonize (bool): true to run the reactor in a background process - print_pidfile (bool): whether to print the pid file, if daemonize is True - logger (logging.Logger): logger instance to pass to Daemonize - run_command (Callable[]): callable that actually runs the reactor + pid_file: name of pid file to write to if daemonize is True + daemonize: true to run the reactor in a background process + print_pidfile: whether to print the pid file, if daemonize is True + logger: logger instance to pass to Daemonize + run_command: callable that actually runs the reactor """ - def run(): + def run() -> None: logger.info("Running") setup_jemalloc_stats() change_resource_limit(soft_file_limit) @@ -185,7 +212,7 @@ def redirect_stdio_to_logs() -> None: print("Redirected stdout/stderr to logs") -def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: +def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None: """Register a callback with the reactor, to be called once it is running This can be used to initialise parts of the system which require an asynchronous @@ -195,7 +222,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: will exit. """ - async def wrapper(): + async def wrapper() -> None: try: await cb(*args, **kwargs) except Exception: @@ -224,7 +251,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics(bind_addresses, port): +def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: """ Start Prometheus metrics server. """ @@ -236,11 +263,11 @@ def listen_metrics(bind_addresses, port): def listen_manhole( - bind_addresses: Iterable[str], + bind_addresses: Collection[str], port: int, manhole_settings: ManholeConfig, manhole_globals: dict, -): +) -> None: # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so # suppress the warning for now. 
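`register_start` above is a startup idiom worth noting: asynchronous initialisation cannot run before the reactor starts, so the callback is wrapped and handed to `callWhenRunning`, and a failure is treated as fatal rather than leaving a half-initialised process. A stripped-down sketch of the same shape, with error handling simplified to stopping the reactor:

    import traceback

    from twisted.internet import defer, reactor

    async def init() -> None:
        ...  # open databases, load signing keys, etc.

    def _start() -> None:
        async def wrapper() -> None:
            try:
                await init()
            except Exception:
                traceback.print_exc()
                reactor.stop()  # fatal: do not keep running half-initialised
        defer.ensureDeferred(wrapper())

    reactor.callWhenRunning(_start)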
@@ -259,12 +286,18 @@ def listen_manhole( ) -def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): +def listen_tcp( + bind_addresses: Collection[str], + port: int, + factory: ServerFactory, + reactor: IReactorTCP = reactor, + backlog: int = 50, +) -> List[Port]: """ Create a TCP socket for a port and several addresses Returns: - list[twisted.internet.tcp.Port]: listening for TCP connections + list of twisted.internet.tcp.Port listening for TCP connections """ r = [] for address in bind_addresses: @@ -273,12 +306,19 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) - return r + # IReactorTCP returns an object implementing IListeningPort from listenTCP, + # but we know it will be a Port instance. + return r # type: ignore[return-value] def listen_ssl( - bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50 -): + bind_addresses: Collection[str], + port: int, + factory: ServerFactory, + context_factory: IOpenSSLContextFactory, + reactor: IReactorSSL = reactor, + backlog: int = 50, +) -> List[Port]: """ Create a TLS-over-TCP socket for a port and several addresses @@ -294,10 +334,13 @@ def listen_ssl( except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) - return r + # IReactorSSL incorrectly declares that an int is returned from listenSSL, + # it actually returns an object implementing IListeningPort, but we know it + # will be a Port instance. + return r # type: ignore[return-value] -def refresh_certificate(hs: "HomeServer"): +def refresh_certificate(hs: "HomeServer") -> None: """ Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. @@ -329,7 +372,7 @@ def refresh_certificate(hs: "HomeServer"): logger.info("Context factories updated.") -async def start(hs: "HomeServer"): +async def start(hs: "HomeServer") -> None: """ Start a Synapse server or worker. @@ -360,7 +403,7 @@ async def start(hs: "HomeServer"): if hasattr(signal, "SIGHUP"): @wrap_as_background_process("sighup") - def handle_sighup(*args, **kwargs): + async def handle_sighup(*args: Any, **kwargs: Any) -> None: # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. sdnotify(b"RELOADING=1") @@ -373,7 +416,7 @@ async def start(hs: "HomeServer"): # We defer running the sighup handlers until next reactor tick. This # is so that we're in a sane state, e.g. flushing the logs may fail # if the sighup happens in the middle of writing a log entry. - def run_sighup(*args, **kwargs): + def run_sighup(*args: Any, **kwargs: Any) -> None: # `callFromThread` should be "signal safe" as well as thread # safe. reactor.callFromThread(handle_sighup, *args, **kwargs) @@ -436,12 +479,8 @@ async def start(hs: "HomeServer"): atexit.register(gc.freeze) -def setup_sentry(hs: "HomeServer"): - """Enable sentry integration, if enabled in configuration - - Args: - hs - """ +def setup_sentry(hs: "HomeServer") -> None: + """Enable sentry integration, if enabled in configuration""" if not hs.config.metrics.sentry_enabled: return @@ -466,7 +505,7 @@ def setup_sentry(hs: "HomeServer"): scope.set_tag("worker_name", name) -def setup_sdnotify(hs: "HomeServer"): +def setup_sdnotify(hs: "HomeServer") -> None: """Adds process state hooks to tell systemd what we are up to.""" # Tell systemd our state, if we're using it. 
This will silently fail if @@ -481,7 +520,7 @@ def setup_sdnotify(hs: "HomeServer"): sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET") -def sdnotify(state): +def sdnotify(state: bytes) -> None: """ Send a notification to systemd, if the NOTIFY_SOCKET env var is set. @@ -490,7 +529,7 @@ def sdnotify(state): package which many OSes don't include as a matter of principle. Args: - state (bytes): notification to send + state: notification to send """ if not isinstance(state, bytes): raise TypeError("sdnotify should be called with a bytes") diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index ad20b1d6aa..42238f7f28 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -17,6 +17,7 @@ import logging import os import sys import tempfile +from typing import List, Optional from twisted.internet import defer, task @@ -25,6 +26,7 @@ from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.events import EventBase from synapse.handlers.admin import ExfiltrationWriter from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore @@ -40,6 +42,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.server import HomeServer from synapse.storage.databases.main.room import RoomWorkerStore +from synapse.types import StateMap from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -65,16 +68,11 @@ class AdminCmdSlavedStore( class AdminCmdServer(HomeServer): - DATASTORE_CLASS = AdminCmdSlavedStore + DATASTORE_CLASS = AdminCmdSlavedStore # type: ignore -async def export_data_command(hs: HomeServer, args): - """Export data for a user. - - Args: - hs - args (argparse.Namespace) - """ +async def export_data_command(hs: HomeServer, args: argparse.Namespace) -> None: + """Export data for a user.""" user_id = args.user_id directory = args.output_directory @@ -92,12 +90,12 @@ class FileExfiltrationWriter(ExfiltrationWriter): Note: This writes to disk on the main reactor thread. Args: - user_id (str): The user whose data is being exfiltrated. - directory (str|None): The directory to write the data to, if None then - will write to a temporary directory. + user_id: The user whose data is being exfiltrated. + directory: The directory to write the data to, if None then will write + to a temporary directory. 
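`sdnotify` above deliberately re-implements a sliver of the systemd notification protocol rather than depending on the `systemd` Python package: the whole protocol is one datagram sent to the unix socket named by NOTIFY_SOCKET. A minimal sketch of what such a helper does (abstract-namespace socket names start with "@" and need the leading byte replaced with NUL):

    import os
    import socket

    def sdnotify(state: bytes) -> None:
        addr = os.getenv("NOTIFY_SOCKET")
        if addr is None:
            return  # not running under systemd: silently do nothing
        if addr.startswith("@"):
            addr = "\0" + addr[1:]  # abstract socket namespace
        with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock:
            sock.connect(addr)
            sock.sendall(state)  # e.g. b"READY=1" or b"RELOADING=1"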
""" - def __init__(self, user_id, directory=None): + def __init__(self, user_id: str, directory: Optional[str] = None): self.user_id = user_id if directory: @@ -111,7 +109,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): if list(os.listdir(self.base_directory)): raise Exception("Directory must be empty") - def write_events(self, room_id, events): + def write_events(self, room_id: str, events: List[EventBase]) -> None: room_directory = os.path.join(self.base_directory, "rooms", room_id) os.makedirs(room_directory, exist_ok=True) events_file = os.path.join(room_directory, "events") @@ -120,7 +118,9 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in events: print(json.dumps(event.get_pdu_json()), file=f) - def write_state(self, room_id, event_id, state): + def write_state( + self, room_id: str, event_id: str, state: StateMap[EventBase] + ) -> None: room_directory = os.path.join(self.base_directory, "rooms", room_id) state_directory = os.path.join(room_directory, "state") os.makedirs(state_directory, exist_ok=True) @@ -131,7 +131,9 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event.get_pdu_json()), file=f) - def write_invite(self, room_id, event, state): + def write_invite( + self, room_id: str, event: EventBase, state: StateMap[EventBase] + ) -> None: self.write_events(room_id, [event]) # We write the invite state somewhere else as they aren't full events @@ -145,7 +147,9 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event), file=f) - def write_knock(self, room_id, event, state): + def write_knock( + self, room_id: str, event: EventBase, state: StateMap[EventBase] + ) -> None: self.write_events(room_id, [event]) # We write the knock state somewhere else as they aren't full events @@ -159,11 +163,11 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): print(json.dumps(event), file=f) - def finished(self): + def finished(self) -> str: return self.base_directory -def start(config_options): +def start(config_options: List[str]) -> None: parser = argparse.ArgumentParser(description="Synapse Admin Command") HomeServerConfig.add_arguments_to_parser(parser) @@ -231,7 +235,7 @@ def start(config_options): # We also make sure that `_base.start` gets run before we actually run the # command. - async def run(): + async def run() -> None: with LoggingContext("command"): await _base.start(ss) await args.func(ss, args) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 218826741e..e256de2003 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -14,11 +14,10 @@ # limitations under the License. 
import logging import sys -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple from twisted.internet import address -from twisted.web.resource import IResource -from twisted.web.server import Request +from twisted.web.resource import Resource import synapse import synapse.events @@ -27,7 +26,8 @@ from synapse.api.urls import ( CLIENT_API_PREFIX, FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, - MEDIA_PREFIX, + MEDIA_R0_PREFIX, + MEDIA_V3_PREFIX, SERVER_KEY_V2_PREFIX, ) from synapse.app import _base @@ -44,7 +44,7 @@ from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource, OptionsResource from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.http.site import SynapseSite +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -113,12 +113,14 @@ from synapse.storage.databases.main.monthly_active_users import ( ) from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.room import RoomWorkerStore +from synapse.storage.databases.main.room_batch import RoomBatchStore from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.session import SessionStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore +from synapse.types import JsonDict from synapse.util.httpresourcetree import create_resource_tree from synapse.util.versionstring import get_version_string @@ -143,7 +145,9 @@ class KeyUploadServlet(RestServlet): self.http_client = hs.get_simple_http_client() self.main_uri = hs.config.worker.worker_main_http_uri - async def on_POST(self, request: Request, device_id: Optional[str]): + async def on_POST( + self, request: SynapseRequest, device_id: Optional[str] + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -187,9 +191,8 @@ class KeyUploadServlet(RestServlet): # If the header exists, add to the comma-separated list of the first # instance of the header. Otherwise, generate a new header. if x_forwarded_for: - x_forwarded_for = [ - x_forwarded_for[0] + b", " + previous_host - ] + x_forwarded_for[1:] + new_header = [x_forwarded_for[0] + b", " + previous_host] + x_forwarded_for = new_header + x_forwarded_for[1:] else: x_forwarded_for = [previous_host] headers[b"X-Forwarded-For"] = x_forwarded_for @@ -238,6 +241,7 @@ class GenericWorkerSlavedStore( SlavedEventStore, SlavedKeyStore, RoomWorkerStore, + RoomBatchStore, DirectoryStore, SlavedApplicationServiceStore, SlavedRegistrationStore, @@ -253,13 +257,16 @@ class GenericWorkerSlavedStore( SessionStore, BaseSlavedStore, ): - pass + # Properties that multiple storage classes define. Tell mypy what the + # expected type is. 
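To make the X-Forwarded-For handling in `on_POST` above concrete: if the proxied request already carries `X-Forwarded-For: client, proxy1` and the previous hop is 10.0.0.5, the onward request should carry `X-Forwarded-For: client, proxy1, 10.0.0.5`. The list manipulation, worked through on raw header bytes as twisted returns them:

    x_forwarded_for = [b"client, proxy1"]
    previous_host = b"10.0.0.5"

    if x_forwarded_for:
        new_header = [x_forwarded_for[0] + b", " + previous_host]
        x_forwarded_for = new_header + x_forwarded_for[1:]
    else:
        x_forwarded_for = [previous_host]

    assert x_forwarded_for == [b"client, proxy1, 10.0.0.5"]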
+ server_name: str + config: HomeServerConfig class GenericWorkerServer(HomeServer): - DATASTORE_CLASS = GenericWorkerSlavedStore + DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore - def _listen_http(self, listener_config: ListenerConfig): + def _listen_http(self, listener_config: ListenerConfig) -> None: port = listener_config.port bind_addresses = listener_config.bind_addresses @@ -267,10 +274,10 @@ class GenericWorkerServer(HomeServer): site_tag = listener_config.http_options.tag if site_tag is None: - site_tag = port + site_tag = str(port) # We always include a health resource. - resources: Dict[str, IResource] = {"/health": HealthResource()} + resources: Dict[str, Resource] = {"/health": HealthResource()} for res in listener_config.http_options.resources: for name in res.names: @@ -334,7 +341,8 @@ class GenericWorkerServer(HomeServer): resources.update( { - MEDIA_PREFIX: media_repo, + MEDIA_R0_PREFIX: media_repo, + MEDIA_V3_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo, "/_synapse/admin": admin_resource, } @@ -386,7 +394,7 @@ class GenericWorkerServer(HomeServer): logger.info("Synapse worker now listening on port %d", port) - def start_listening(self): + def start_listening(self) -> None: for listener in self.config.worker.worker_listeners: if listener.type == "http": self._listen_http(listener) @@ -411,7 +419,7 @@ class GenericWorkerServer(HomeServer): self.get_tcp_replication().start_replication(self) -def start(config_options): +def start(config_options: List[str]) -> None: try: config = HomeServerConfig.load_config("Synapse worker", config_options) except ConfigError as e: @@ -497,6 +505,10 @@ def start(config_options): _base.start_worker_reactor("synapse-generic-worker", config) -if __name__ == "__main__": +def main() -> None: with LoggingContext("main"): start(sys.argv[1:]) + + +if __name__ == "__main__": + main() diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 336c279a44..dd76e07321 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -16,10 +16,10 @@ import logging import os import sys -from typing import Iterator +from typing import Dict, Iterable, Iterator, List -from twisted.internet import reactor -from twisted.web.resource import EncodingResourceWrapper, IResource +from twisted.internet.tcp import Port +from twisted.web.resource import EncodingResourceWrapper, Resource from twisted.web.server import GzipEncoderFactory from twisted.web.static import File @@ -29,7 +29,8 @@ from synapse import events from synapse.api.urls import ( FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, - MEDIA_PREFIX, + MEDIA_R0_PREFIX, + MEDIA_V3_PREFIX, SERVER_KEY_V2_PREFIX, STATIC_PREFIX, WEB_CLIENT_PREFIX, @@ -76,23 +77,27 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.homeserver") -def gz_wrap(r): +def gz_wrap(r: Resource) -> Resource: return EncodingResourceWrapper(r, [GzipEncoderFactory()]) class SynapseHomeServer(HomeServer): - DATASTORE_CLASS = DataStore + DATASTORE_CLASS = DataStore # type: ignore - def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig): + def _listener_http( + self, config: HomeServerConfig, listener_config: ListenerConfig + ) -> Iterable[Port]: port = listener_config.port bind_addresses = listener_config.bind_addresses tls = listener_config.tls + # Must exist since this is an HTTP listener. 
+ assert listener_config.http_options is not None site_tag = listener_config.http_options.tag if site_tag is None: site_tag = str(port) # We always include a health resource. - resources = {"/health": HealthResource()} + resources: Dict[str, Resource] = {"/health": HealthResource()} for res in listener_config.http_options.resources: for name in res.names: @@ -111,7 +116,7 @@ class SynapseHomeServer(HomeServer): ("listeners", site_tag, "additional_resources", "<%s>" % (path,)), ) handler = handler_cls(config, module_api) - if IResource.providedBy(handler): + if isinstance(handler, Resource): resource = handler elif hasattr(handler, "handle_request"): resource = AdditionalResource(self, handler.handle_request) @@ -128,7 +133,7 @@ class SynapseHomeServer(HomeServer): # try to find something useful to redirect '/' to if WEB_CLIENT_PREFIX in resources: - root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX) + root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX) elif STATIC_PREFIX in resources: root_resource = RootOptionsRedirectResource(STATIC_PREFIX) else: @@ -145,6 +150,8 @@ class SynapseHomeServer(HomeServer): ) if tls: + # refresh_certificate should have been called before this. + assert self.tls_server_context_factory is not None ports = listen_ssl( bind_addresses, port, @@ -165,20 +172,21 @@ class SynapseHomeServer(HomeServer): return ports - def _configure_named_resource(self, name, compress=False): + def _configure_named_resource( + self, name: str, compress: bool = False + ) -> Dict[str, Resource]: """Build a resource map for a named resource Args: - name (str): named resource: one of "client", "federation", etc - compress (bool): whether to enable gzip compression for this - resource + name: named resource: one of "client", "federation", etc + compress: whether to enable gzip compression for this resource Returns: - dict[str, Resource]: map from path to HTTP resource + map from path to HTTP resource """ - resources = {} + resources: Dict[str, Resource] = {} if name == "client": - client_resource = ClientRestResource(self) + client_resource: Resource = ClientRestResource(self) if compress: client_resource = gz_wrap(client_resource) @@ -186,6 +194,8 @@ class SynapseHomeServer(HomeServer): { "/_matrix/client/api/v1": client_resource, "/_matrix/client/r0": client_resource, + "/_matrix/client/v1": client_resource, + "/_matrix/client/v3": client_resource, "/_matrix/client/unstable": client_resource, "/_matrix/client/v2_alpha": client_resource, "/_matrix/client/versions": client_resource, @@ -207,7 +217,7 @@ class SynapseHomeServer(HomeServer): if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource - consent_resource = ConsentResource(self) + consent_resource: Resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) resources.update({"/_matrix/consent": consent_resource}) @@ -237,7 +247,11 @@ class SynapseHomeServer(HomeServer): if self.config.server.enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( - {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} + { + MEDIA_R0_PREFIX: media_repo, + MEDIA_V3_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + } ) elif name == "media": raise ConfigError( @@ -277,7 +291,7 @@ class SynapseHomeServer(HomeServer): return resources - def start_listening(self): + def start_listening(self) -> None: if self.config.redis.redis_enabled: # If redis is enabled we connect via the replication command handler # in the same 
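One subtlety in the `additional_resources` hunk above: `IResource.providedBy(handler)` is a zope.interface check, whereas `isinstance(handler, Resource)` requires the concrete `twisted.web.resource.Resource` base class, so the new check is strictly narrower. A module that declared the interface without subclassing would pass the old test but not the new one:

    from zope.interface import implementer

    from twisted.web.resource import IResource, Resource

    @implementer(IResource)
    class InterfaceOnly:  # declares IResource but is not a Resource
        isLeaf = True

    class Subclassed(Resource):
        pass

    assert IResource.providedBy(InterfaceOnly())
    assert not isinstance(InterfaceOnly(), Resource)  # no longer accepted
    assert isinstance(Subclassed(), Resource)         # still accepted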
way as the workers (since we're effectively a client @@ -303,7 +317,9 @@ class SynapseHomeServer(HomeServer): ReplicationStreamProtocolFactory(self), ) for s in services: - reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) + self.get_reactor().addSystemEventTrigger( + "before", "shutdown", s.stopListening + ) elif listener.type == "metrics": if not self.config.metrics.enable_metrics: logger.warning( @@ -318,14 +334,13 @@ class SynapseHomeServer(HomeServer): logger.warning("Unrecognized listener type: %s", listener.type) -def setup(config_options): +def setup(config_options: List[str]) -> SynapseHomeServer: """ Args: - config_options_options: The options passed to Synapse. Usually - `sys.argv[1:]`. + config_options: The options passed to Synapse. Usually `sys.argv[1:]`. Returns: - HomeServer + A homeserver instance. """ try: config = HomeServerConfig.load_or_generate_config( @@ -343,6 +358,13 @@ # generating config files and shouldn't try to continue. sys.exit(0) + if config.worker.worker_app: + raise ConfigError( + "You have specified `worker_app` in the config but are attempting to start a non-worker " + "instance. Please use `python -m synapse.app.generic_worker` instead (or remove the option if this is the main process)." + ) + events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage @@ -364,7 +386,7 @@ def setup(config_options): except Exception as e: handle_startup_exception(e) - async def start(): + async def start() -> None: # Load the OIDC provider metadatas, if OIDC is enabled. if hs.config.oidc.oidc_enabled: oidc = hs.get_oidc_handler() @@ -404,39 +426,15 @@ def format_config_error(e: ConfigError) -> Iterator[str]: yield ":\n %s" % (e.msg,) - e = e.__cause__ + parent_e = e.__cause__ indent = 1 - while e: + while parent_e: indent += 1 - yield ":\n%s%s" % (" " * indent, str(e)) - e = e.__cause__ - - -def run(hs: HomeServer): - PROFILE_SYNAPSE = False - if PROFILE_SYNAPSE: - - def profile(func): - from cProfile import Profile - from threading import current_thread - - def profiled(*args, **kargs): - profile = Profile() - profile.enable() - func(*args, **kargs) - profile.disable() - ident = current_thread().ident - profile.dump_stats( - "/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident) - ) - - return profiled - - from twisted.python.threadpool import ThreadPool + yield ":\n%s%s" % (" " * indent, str(parent_e)) + parent_e = parent_e.__cause__ - ThreadPool._worker = profile(ThreadPool._worker) - reactor.run = profile(reactor.run) +def run(hs: HomeServer) -> None: _base.start_reactor( "synapse-homeserver", soft_file_limit=hs.config.server.soft_file_limit, @@ -448,7 +446,7 @@ def run(hs: HomeServer): ) -def main(): +def main() -> None: with LoggingContext("main"): # check base requirements check_requirements()
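The `format_config_error` rewrite above renames the loop variable so the original exception is not clobbered while walking the `__cause__` chain; each chained cause is rendered at increasing indentation. Roughly (the optional `path` argument here is an assumption about ConfigError's signature, and the exact output format is approximate):

    from synapse.app.homeserver import format_config_error
    from synapse.config._base import ConfigError

    try:
        try:
            raise ValueError("'8448x' is not a valid port")
        except ValueError as inner:
            raise ConfigError("bad value", ("listeners",)) from inner
    except ConfigError as e:
        print("".join(format_config_error(e)))
    # Prints something like:
    #   Error in configuration at 'listeners':
    #     bad value:
    #       '8448x' is not a valid port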
Contains the list of processes we will be monitoring # currently either 0 or 1 -_stats_process = [] +_stats_process: List[Tuple[int, "resource.struct_rusage"]] = [] # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") @@ -45,9 +46,15 @@ registered_reserved_users_mau_gauge = Gauge( @wrap_as_background_process("phone_stats_home") -async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process): +async def phone_stats_home( + hs: "HomeServer", + stats: JsonDict, + stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process, +) -> None: logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) + # Ensure the homeserver has started. + assert hs.start_time is not None uptime = int(now - hs.start_time) if uptime < 0: uptime = 0 @@ -146,15 +153,15 @@ async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process logger.warning("Error reporting stats: %s", e) -def start_phone_stats_home(hs: "HomeServer"): +def start_phone_stats_home(hs: "HomeServer") -> None: """ Start the background tasks which report phone home stats. """ clock = hs.get_clock() - stats = {} + stats: JsonDict = {} - def performance_stats_init(): + def performance_stats_init() -> None: _stats_process.clear() _stats_process.append( (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) @@ -170,10 +177,10 @@ def start_phone_stats_home(hs: "HomeServer"): hs.get_datastore().reap_monthly_active_users() @wrap_as_background_process("generate_monthly_active_users") - async def generate_monthly_active_users(): + async def generate_monthly_active_users() -> None: current_mau_count = 0 current_mau_count_by_service = {} - reserved_users = () + reserved_users: Sized = () store = hs.get_datastore() if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 6504c6bd3f..f9d3bd337d 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import re +from enum import Enum from typing import TYPE_CHECKING, Iterable, List, Match, Optional from synapse.api.constants import EventTypes @@ -27,7 +28,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ApplicationServiceState: +class ApplicationServiceState(Enum): DOWN = "down" UP = "up" diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d08f6bbd7f..f51b636417 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -231,13 +231,32 @@ class ApplicationServiceApi(SimpleHttpClient): json_body=body, args={"access_token": service.hs_token}, ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "push_bulk to %s succeeded! 
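The switch of `ApplicationServiceState` to an `Enum` above is behavioural, not cosmetic: enum members compare by identity rather than equalling their string payload, so anything that previously compared against the raw "up"/"down" strings (for example values read back from the database) must now convert explicitly. A quick illustration:

    from enum import Enum

    class ApplicationServiceState(Enum):
        DOWN = "down"
        UP = "up"

    assert ApplicationServiceState.UP != "up"        # no longer equal
    assert ApplicationServiceState.UP.value == "up"  # what to persist
    assert ApplicationServiceState("up") is ApplicationServiceState.UP  # convert on read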
events=%s", + uri, + [event.get("event_id") for event in events], + ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) return True except CodeMessageException as e: - logger.warning("push_bulk to %s received %s", uri, e.code) + logger.warning( + "push_bulk to %s received code=%s msg=%s", + uri, + e.code, + e.msg, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) except Exception as ex: - logger.warning("push_bulk to %s threw exception %s", uri, ex) + logger.warning( + "push_bulk to %s threw exception(%s) %s args=%s", + uri, + type(ex).__name__, + ex, + ex.args, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) failed_transactions_counter.labels(service.id).inc() return False diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index c555f5f914..b2a7a89a35 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys +from typing import List from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig -def main(args): +def main(args: List[str]) -> None: action = args[1] if len(args) > 1 and args[1] == "read" else None # If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]` # will be the key to read. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 7c4428a138..1265738dc1 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -20,7 +20,18 @@ import os from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Iterable, List, MutableMapping, Optional, Union +from typing import ( + Any, + Dict, + Iterable, + List, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import attr import jinja2 @@ -78,7 +89,7 @@ CONFIG_FILE_HEADER = """\ """ -def path_exists(file_path): +def path_exists(file_path: str) -> bool: """Check if a file exists Unlike os.path.exists, this throws an exception if there is an error @@ -86,7 +97,7 @@ def path_exists(file_path): the parent dir). Returns: - bool: True if the file exists; False if not. + True if the file exists; False if not. """ try: os.stat(file_path) @@ -102,15 +113,15 @@ class Config: A configuration section, containing configuration keys and values. Attributes: - section (str): The section title of this config object, such as + section: The section title of this config object, such as "tls" or "logger". This is used to refer to it on the root logger (for example, `config.tls.some_option`). Must be defined in subclasses. 
""" - section = None + section: str - def __init__(self, root_config=None): + def __init__(self, root_config: "RootConfig" = None): self.root = root_config # Get the path to the default Synapse template directory @@ -119,7 +130,7 @@ class Config: ) @staticmethod - def parse_size(value): + def parse_size(value: Union[str, int]) -> int: if isinstance(value, int): return value sizes = {"K": 1024, "M": 1024 * 1024} @@ -162,15 +173,15 @@ class Config: return int(value) * size @staticmethod - def abspath(file_path): + def abspath(file_path: str) -> str: return os.path.abspath(file_path) if file_path else file_path @classmethod - def path_exists(cls, file_path): + def path_exists(cls, file_path: str) -> bool: return path_exists(file_path) @classmethod - def check_file(cls, file_path, config_name): + def check_file(cls, file_path: Optional[str], config_name: str) -> str: if file_path is None: raise ConfigError("Missing config for %s." % (config_name,)) try: @@ -183,7 +194,7 @@ class Config: return cls.abspath(file_path) @classmethod - def ensure_directory(cls, dir_path): + def ensure_directory(cls, dir_path: str) -> str: dir_path = cls.abspath(dir_path) os.makedirs(dir_path, exist_ok=True) if not os.path.isdir(dir_path): @@ -191,7 +202,7 @@ class Config: return dir_path @classmethod - def read_file(cls, file_path, config_name): + def read_file(cls, file_path: Any, config_name: str) -> str: """Deprecated: call read_file directly""" return read_file(file_path, (config_name,)) @@ -284,6 +295,9 @@ class Config: return [env.get_template(filename) for filename in filenames] +TRootConfig = TypeVar("TRootConfig", bound="RootConfig") + + class RootConfig: """ Holder of an application's configuration. @@ -308,7 +322,9 @@ class RootConfig: raise Exception("Failed making %s: %r" % (config_class.section, e)) setattr(self, config_class.section, conf) - def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: + def invoke_all( + self, func_name: str, *args: Any, **kwargs: Any + ) -> MutableMapping[str, Any]: """ Invoke a function on all instantiated config objects this RootConfig is configured to use. @@ -317,6 +333,7 @@ class RootConfig: func_name: Name of function to invoke *args **kwargs + Returns: ordered dictionary of config section name and the result of the function from it. @@ -332,7 +349,7 @@ class RootConfig: return res @classmethod - def invoke_all_static(cls, func_name: str, *args, **kwargs): + def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: any) -> None: """ Invoke a static function on config objects this RootConfig is configured to use. @@ -341,6 +358,7 @@ class RootConfig: func_name: Name of function to invoke *args **kwargs + Returns: ordered dictionary of config section name and the result of the function from it. @@ -351,16 +369,16 @@ class RootConfig: def generate_config( self, - config_dir_path, - data_dir_path, - server_name, - generate_secrets=False, - report_stats=None, - open_private_ports=False, - listeners=None, - tls_certificate_path=None, - tls_private_key_path=None, - ): + config_dir_path: str, + data_dir_path: str, + server_name: str, + generate_secrets: bool = False, + report_stats: Optional[bool] = None, + open_private_ports: bool = False, + listeners: Optional[List[dict]] = None, + tls_certificate_path: Optional[str] = None, + tls_private_key_path: Optional[str] = None, + ) -> str: """ Build a default configuration file @@ -368,27 +386,27 @@ class RootConfig: (eg with --generate_config). 
Args: - config_dir_path (str): The path where the config files are kept. Used to + config_dir_path: The path where the config files are kept. Used to create filenames for things like the log config and the signing key. - data_dir_path (str): The path where the data files are kept. Used to create + data_dir_path: The path where the data files are kept. Used to create filenames for things like the database and media store. - server_name (str): The server name. Used to initialise the server_name + server_name: The server name. Used to initialise the server_name config param, but also used in the names of some of the config files. - generate_secrets (bool): True if we should generate new secrets for things + generate_secrets: True if we should generate new secrets for things like the macaroon_secret_key. If False, these parameters will be left unset. - report_stats (bool|None): Initial setting for the report_stats setting. + report_stats: Initial setting for the report_stats setting. If None, report_stats will be left unset. - open_private_ports (bool): True to leave private ports (such as the non-TLS + open_private_ports: True to leave private ports (such as the non-TLS HTTP listener) open to the internet. - listeners (list(dict)|None): A list of descriptions of the listeners - synapse should start with each of which specifies a port (str), a list of + listeners: A list of descriptions of the listeners synapse should + start with each of which specifies a port (int), a list of resources (list(str)), tls (bool) and type (str). For example: [{ "port": 8448, @@ -403,16 +421,12 @@ class RootConfig: "type": "http", }], + tls_certificate_path: The path to the tls certificate. - database (str|None): The database type to configure, either `psycog2` - or `sqlite3`. - - tls_certificate_path (str|None): The path to the tls certificate. - - tls_private_key_path (str|None): The path to the tls private key. + tls_private_key_path: The path to the tls private key. Returns: - str: the yaml config file + The yaml config file """ return CONFIG_FILE_HEADER + "\n\n".join( @@ -432,12 +446,15 @@ class RootConfig: ) @classmethod - def load_config(cls, description, argv): + def load_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> TRootConfig: """Parse the commandline and config files Doesn't support config-file-generation: used by the worker apps. - Returns: Config object. + Returns: + Config object. """ config_parser = argparse.ArgumentParser(description=description) cls.add_arguments_to_parser(config_parser) @@ -446,7 +463,7 @@ class RootConfig: return obj @classmethod - def add_arguments_to_parser(cls, config_parser): + def add_arguments_to_parser(cls, config_parser: argparse.ArgumentParser) -> None: """Adds all the config flags to an ArgumentParser. Doesn't support config-file-generation: used by the worker apps. @@ -454,7 +471,7 @@ class RootConfig: Used for workers where we want to add extra flags/subcommands. Args: - config_parser (ArgumentParser): App description + config_parser: App description """ config_parser.add_argument( @@ -477,7 +494,9 @@ class RootConfig: cls.invoke_all_static("add_arguments", config_parser) @classmethod - def load_config_with_parser(cls, parser, argv): + def load_config_with_parser( + cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str] + ) -> Tuple[TRootConfig, argparse.Namespace]: """Parse the commandline and config files with the given parser Doesn't support config-file-generation: used by the worker apps. 
@@ -485,13 +504,12 @@ class RootConfig: Used for workers where we want to add extra flags/subcommands. Args: - parser (ArgumentParser) - argv (list[str]) + parser + argv Returns: - tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed - config object and the parsed argparse.Namespace object from - `parser.parse_args(..)` + Returns the parsed config object and the parsed argparse.Namespace + object from parser.parse_args(..)` """ obj = cls() @@ -520,12 +538,15 @@ class RootConfig: return obj, config_args @classmethod - def load_or_generate_config(cls, description, argv): + def load_or_generate_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> Optional[TRootConfig]: """Parse the commandline and config files Supports generation of config files, so is used for the main homeserver app. - Returns: Config object, or None if --generate-config or --generate-keys was set + Returns: + Config object, or None if --generate-config or --generate-keys was set """ parser = argparse.ArgumentParser(description=description) parser.add_argument( @@ -680,16 +701,21 @@ class RootConfig: return obj - def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None): + def parse_config_dict( + self, + config_dict: Dict[str, Any], + config_dir_path: Optional[str] = None, + data_dir_path: Optional[str] = None, + ) -> None: """Read the information from the config dict into this Config object. Args: - config_dict (dict): Configuration data, as read from the yaml + config_dict: Configuration data, as read from the yaml - config_dir_path (str): The path where the config files are kept. Used to + config_dir_path: The path where the config files are kept. Used to create filenames for things like the log config and the signing key. - data_dir_path (str): The path where the data files are kept. Used to create + data_dir_path: The path where the data files are kept. Used to create filenames for things like the database and media store. """ self.invoke_all( @@ -699,17 +725,20 @@ class RootConfig: data_dir_path=data_dir_path, ) - def generate_missing_files(self, config_dict, config_dir_path): + def generate_missing_files( + self, config_dict: Dict[str, Any], config_dir_path: str + ) -> None: self.invoke_all("generate_files", config_dict, config_dir_path) -def read_config_files(config_files): +def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: """Read the config files into a dict Args: - config_files (iterable[str]): A list of the config files to read + config_files: A list of the config files to read - Returns: dict + Returns: + The configuration dictionary. """ specified_config = {} for config_file in config_files: @@ -733,17 +762,17 @@ def read_config_files(config_files): return specified_config -def find_config_files(search_paths): +def find_config_files(search_paths: List[str]) -> List[str]: """Finds config files using a list of search paths. If a path is a file then that file path is added to the list. If a search path is a directory then all the "*.yaml" files in that directory are added to the list in sorted order. Args: - search_paths(list(str)): A list of paths to search. + search_paths: A list of paths to search. Returns: - list(str): A list of file paths. + A list of file paths. """ config_files = [] @@ -777,7 +806,7 @@ def find_config_files(search_paths): return config_files -@attr.s +@attr.s(auto_attribs=True) class ShardedWorkerHandlingConfig: """Algorithm for choosing which instance is responsible for handling some sharded work. 
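The `TRootConfig = TypeVar(..., bound="RootConfig")` plumbing used by `load_config`, `load_config_with_parser` and `load_or_generate_config` above is the standard trick for making a classmethod constructor return the subclass it was invoked on, so `HomeServerConfig.load_config(...)` type-checks as `HomeServerConfig` rather than a bare `RootConfig`. In miniature:

    from typing import List, Type, TypeVar

    T = TypeVar("T", bound="Base")

    class Base:
        @classmethod
        def load(cls: Type[T], argv: List[str]) -> T:
            obj = cls()
            ...  # parse argv into obj
            return obj

    class HomeServerConfig(Base):
        pass

    config = HomeServerConfig.load([])  # mypy sees HomeServerConfig, not Base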
@@ -787,7 +816,7 @@ class ShardedWorkerHandlingConfig: below). """ - instances = attr.ib(type=List[str]) + instances: List[str] def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key.""" diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index c1d9069798..1eb5f5a68c 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,4 +1,18 @@ -from typing import Any, Iterable, List, Optional +import argparse +from typing import ( + Any, + Dict, + Iterable, + List, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +import jinja2 from synapse.config import ( account_validity, @@ -19,6 +33,7 @@ from synapse.config import ( logger, metrics, modules, + oembed, oidc, password_auth_providers, push, @@ -27,6 +42,7 @@ from synapse.config import ( registration, repository, retention, + room, room_directory, saml2, server, @@ -51,7 +67,9 @@ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str MISSING_REPORT_STATS_SPIEL: str MISSING_SERVER_NAME: str -def path_exists(file_path: str): ... +def path_exists(file_path: str) -> bool: ... + +TRootConfig = TypeVar("TRootConfig", bound="RootConfig") class RootConfig: server: server.ServerConfig @@ -61,6 +79,7 @@ class RootConfig: logging: logger.LoggingConfig ratelimiting: ratelimiting.RatelimitConfig media: repository.ContentRepositoryConfig + oembed: oembed.OembedConfig captcha: captcha.CaptchaConfig voip: voip.VoipConfig registration: registration.RegistrationConfig @@ -80,6 +99,7 @@ class RootConfig: authproviders: password_auth_providers.PasswordAuthProviderConfig push: push.PushConfig spamchecker: spam_checker.SpamCheckerConfig + room: room.RoomConfig groups: groups.GroupsConfig userdirectory: user_directory.UserDirectoryConfig consent: consent.ConsentConfig @@ -87,72 +107,85 @@ class RootConfig: servernotices: server_notices.ServerNoticesConfig roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig - tracer: tracer.TracerConfig + tracing: tracer.TracerConfig redis: redis.RedisConfig modules: modules.ModulesConfig caches: cache.CacheConfig federation: federation.FederationConfig retention: retention.RetentionConfig - config_classes: List = ... + config_classes: List[Type["Config"]] = ... def __init__(self) -> None: ... - def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ... + def invoke_all( + self, func_name: str, *args: Any, **kwargs: Any + ) -> MutableMapping[str, Any]: ... @classmethod def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ... - def __getattr__(self, item: str): ... def parse_config_dict( self, - config_dict: Any, - config_dir_path: Optional[Any] = ..., - data_dir_path: Optional[Any] = ..., + config_dict: Dict[str, Any], + config_dir_path: Optional[str] = ..., + data_dir_path: Optional[str] = ..., ) -> None: ... - read_config: Any = ... def generate_config( self, config_dir_path: str, data_dir_path: str, server_name: str, generate_secrets: bool = ..., - report_stats: Optional[str] = ..., + report_stats: Optional[bool] = ..., open_private_ports: bool = ..., listeners: Optional[Any] = ..., - database_conf: Optional[Any] = ..., tls_certificate_path: Optional[str] = ..., tls_private_key_path: Optional[str] = ..., - ): ... + ) -> str: ... @classmethod - def load_or_generate_config(cls, description: Any, argv: Any): ... 
+ def load_or_generate_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> Optional[TRootConfig]: ... @classmethod - def load_config(cls, description: Any, argv: Any): ... + def load_config( + cls: Type[TRootConfig], description: str, argv: List[str] + ) -> TRootConfig: ... @classmethod - def add_arguments_to_parser(cls, config_parser: Any) -> None: ... + def add_arguments_to_parser( + cls, config_parser: argparse.ArgumentParser + ) -> None: ... @classmethod - def load_config_with_parser(cls, parser: Any, argv: Any): ... + def load_config_with_parser( + cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str] + ) -> Tuple[TRootConfig, argparse.Namespace]: ... def generate_missing_files( self, config_dict: dict, config_dir_path: str ) -> None: ... class Config: root: RootConfig + default_template_dir: str def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ... - def __getattr__(self, item: str, from_root: bool = ...): ... @staticmethod - def parse_size(value: Any): ... + def parse_size(value: Union[str, int]) -> int: ... @staticmethod - def parse_duration(value: Any): ... + def parse_duration(value: Union[str, int]) -> int: ... @staticmethod - def abspath(file_path: Optional[str]): ... + def abspath(file_path: Optional[str]) -> str: ... @classmethod - def path_exists(cls, file_path: str): ... + def path_exists(cls, file_path: str) -> bool: ... @classmethod - def check_file(cls, file_path: str, config_name: str): ... + def check_file(cls, file_path: str, config_name: str) -> str: ... @classmethod - def ensure_directory(cls, dir_path: str): ... + def ensure_directory(cls, dir_path: str) -> str: ... @classmethod - def read_file(cls, file_path: str, config_name: str): ... + def read_file(cls, file_path: str, config_name: str) -> str: ... + def read_template(self, filenames: str) -> jinja2.Template: ... + def read_templates( + self, + filenames: List[str], + custom_template_directories: Optional[Iterable[str]] = None, + ) -> List[jinja2.Template]: ... -def read_config_files(config_files: List[str]): ... -def find_config_files(search_paths: List[str]): ... +def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: ... +def find_config_files(search_paths: List[str]) -> List[str]: ... class ShardedWorkerHandlingConfig: instances: List[str] diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 1ebea88db2..e4bb7224a4 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,14 +14,14 @@ # limitations under the License. 
import logging -from typing import Dict +from typing import Dict, List from urllib import parse as urlparse import yaml from netaddr import IPSet from synapse.appservice import ApplicationService -from synapse.types import UserID +from synapse.types import JsonDict, UserID from ._base import Config, ConfigError @@ -30,12 +31,12 @@ logger = logging.getLogger(__name__) class AppServiceConfig(Config): section = "appservice" - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: self.app_service_config_files = config.get("app_service_config_files", []) self.notify_appservices = config.get("notify_appservices", True) self.track_appservice_user_ips = config.get("track_appservice_user_ips", False) - def generate_config_section(cls, **kwargs): + def generate_config_section(cls, **kwargs) -> str: return """\ # A list of application service config files to use # @@ -50,7 +51,9 @@ class AppServiceConfig(Config): """ -def load_appservices(hostname, config_files): +def load_appservices( + hostname: str, config_files: List[str] +) -> List[ApplicationService]: """Returns a list of Application Services from the config files.""" if not isinstance(config_files, list): logger.warning("Expected %s to be a list of AS config files.", config_files) @@ -93,7 +96,9 @@ def load_appservices(hostname, config_files): return appservices -def _load_appservice(hostname, as_info, config_filename): +def _load_appservice( + hostname: str, as_info: JsonDict, config_filename: str +) -> ApplicationService: required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"] for field in required_string_fields: if not isinstance(as_info.get(field), str): @@ -115,9 +120,9 @@ def _load_appservice(hostname, as_info, config_filename): user_id = user.to_string() # Rate limiting for users of this AS is on by default (excludes sender) - rate_limited = True - if isinstance(as_info.get("rate_limited"), bool): - rate_limited = as_info.get("rate_limited") + rate_limited = as_info.get("rate_limited") + if not isinstance(rate_limited, bool): + rate_limited = True # namespace checks if not isinstance(as_info.get("namespaces"), dict): diff --git a/synapse/config/cache.py b/synapse/config/cache.py index d119427ad8..d9d85f98e1 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -1,4 +1,4 @@ -# Copyright 2019 Matrix.org Foundation C.I.C. +# Copyright 2019-2021 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,7 +15,9 @@ import os import re import threading -from typing import Callable, Dict +from typing import Callable, Dict, Optional + +import attr from synapse.python_dependencies import DependencyException, check_requirements @@ -34,13 +36,13 @@ _DEFAULT_FACTOR_SIZE = 0.5 _DEFAULT_EVENT_CACHE_SIZE = "10K" +@attr.s(slots=True, auto_attribs=True) class CacheProperties: - def __init__(self): - # The default factor size for all caches - self.default_factor_size = float( - os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) - ) - self.resize_all_caches_func = None + # The default factor size for all caches + default_factor_size: float = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + resize_all_caches_func: Optional[Callable[[], None]] = None properties = CacheProperties() @@ -62,7 +64,7 @@ def _canonicalise_cache_name(cache_name: str) -> str: def add_resizable_cache( cache_name: str, cache_resize_callback: Callable[[float], None] -): +) -> None: """Register a cache that's size can dynamically change Args: @@ -91,7 +93,7 @@ class CacheConfig(Config): _environ = os.environ @staticmethod - def reset(): + def reset() -> None: """Resets the caches to their defaults. Used for tests.""" properties.default_factor_size = float( os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) @@ -100,7 +102,7 @@ class CacheConfig(Config): with _CACHES_LOCK: _CACHES.clear() - def generate_config_section(self, **kwargs): + def generate_config_section(self, **kwargs) -> str: return """\ ## Caching ## @@ -162,7 +164,7 @@ class CacheConfig(Config): #sync_response_cache_duration: 2m """ - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: self.event_cache_size = self.parse_size( config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) ) @@ -217,7 +219,7 @@ class CacheConfig(Config): expiry_time = cache_config.get("expiry_time") if expiry_time: - self.expiry_time_msec = self.parse_duration(expiry_time) + self.expiry_time_msec: Optional[int] = self.parse_duration(expiry_time) else: self.expiry_time_msec = None @@ -232,7 +234,7 @@ class CacheConfig(Config): # needing an instance of Config properties.resize_all_caches_func = self.resize_all_caches - def resize_all_caches(self): + def resize_all_caches(self) -> None: """Ensure all cache sizes are up to date For each cache, run the mapped callback function with either diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 3f81814043..6f2754092e 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,7 +29,7 @@ class CasConfig(Config): section = "cas" - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: cas_config = config.get("cas_config", None) self.cas_enabled = cas_config and cas_config.get("enabled", True) @@ -51,7 +52,7 @@ class CasConfig(Config): self.cas_displayname_attribute = None self.cas_required_attributes = [] - def generate_config_section(self, config_dir_path, server_name, **kwargs): + def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: return """\ # Enable Central Authentication Service (CAS) for registration and login. 
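One gotcha with the attrs conversion of `CacheProperties` above: under `auto_attribs`, a plain default such as `float(os.environ.get(...))` is evaluated once, at class-definition time, and shared by every instance, which is why the separate `reset()` helper re-reads the environment for tests. If per-instance evaluation were wanted, the attrs idiom would be a factory (a hypothetical variant, not what the code above does):

    import os

    import attr

    @attr.s(slots=True, auto_attribs=True)
    class CacheProperties:
        # Hypothetical: re-read the environment for each new instance.
        default_factor_size: float = attr.ib(
            factory=lambda: float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
        )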
# diff --git a/synapse/config/database.py b/synapse/config/database.py index 651e31b576..06ccf15cd9 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -1,5 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging import os @@ -119,7 +120,7 @@ class DatabaseConfig(Config): self.databases = [] - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: # We *experimentally* support specifying multiple databases via the # `databases` key. This is a map from a label to database config in the # same format as the `database` config option, plus an extra @@ -163,12 +164,12 @@ class DatabaseConfig(Config): self.databases = [DatabaseConnectionConfig("master", database_config)] self.set_databasepath(database_path) - def generate_config_section(self, data_dir_path, **kwargs): + def generate_config_section(self, data_dir_path, **kwargs) -> str: return DEFAULT_CONFIG % { "database_path": os.path.join(data_dir_path, "homeserver.db") } - def read_arguments(self, args): + def read_arguments(self, args: argparse.Namespace) -> None: """ Cases for the cli input: - If no databases are configured and no database_path is set, raise. @@ -194,7 +195,7 @@ class DatabaseConfig(Config): else: logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) - def set_databasepath(self, database_path): + def set_databasepath(self, database_path: str) -> None: if database_path != ":memory:": database_path = self.abspath(database_path) @@ -202,7 +203,7 @@ class DatabaseConfig(Config): self.databases[0].config["args"]["database"] = database_path @staticmethod - def add_arguments(parser): + def add_arguments(parser: argparse.ArgumentParser) -> None: db_group = parser.add_argument_group("database") db_group.add_argument( "-d", diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index afd65fecd3..510b647c63 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -137,33 +137,14 @@ class EmailConfig(Config): if self.root.registration.account_threepid_delegate_email else ThreepidBehaviour.LOCAL ) - # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would - # use an identity server to password reset tokens on its behalf. We now warn the user - # if they have this set and tell them to use the updated option, while using a default - # identity server in the process. - self.using_identity_server_from_trusted_list = False - if ( - not self.root.registration.account_threepid_delegate_email - and config.get("trust_identity_server_for_password_resets", False) is True - ): - # Use the first entry in self.trusted_third_party_id_servers instead - if self.trusted_third_party_id_servers: - # XXX: It's a little confusing that account_threepid_delegate_email is modified - # both in RegistrationConfig and here. We should factor this bit out - first_trusted_identity_server = self.trusted_third_party_id_servers[0] - - # trusted_third_party_id_servers does not contain a scheme whereas - # account_threepid_delegate_email is expected to. 
Presume https - self.root.registration.account_threepid_delegate_email = ( - "https://" + first_trusted_identity_server - ) - self.using_identity_server_from_trusted_list = True - else: - raise ConfigError( - "Attempted to use an identity server from" - '"trusted_third_party_id_servers" but it is empty.' - ) + if config.get("trust_identity_server_for_password_resets"): + raise ConfigError( + 'The config option "trust_identity_server_for_password_resets" ' + 'has been replaced by "account_threepid_delegate". ' + "Please consult the sample config at docs/sample_config.yaml for " + "details and update your config file." + ) self.local_threepid_handling_disabled_due_to_email_config = False if ( diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 8b098ad48d..d78a15097c 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -46,3 +46,6 @@ class ExperimentalConfig(Config): # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) + + # MSC3030 (Jump to date API endpoint) + self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py index 9d295f5856..24c3ef01fc 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt.py @@ -31,6 +31,8 @@ class JWTConfig(Config): self.jwt_secret = jwt_config["secret"] self.jwt_algorithm = jwt_config["algorithm"] + self.jwt_subject_claim = jwt_config.get("subject_claim", "sub") + # The issuer and audiences are optional, if provided, it is asserted # that the claims exist on the JWT. self.jwt_issuer = jwt_config.get("issuer") @@ -46,6 +48,7 @@ class JWTConfig(Config): self.jwt_enabled = False self.jwt_secret = None self.jwt_algorithm = None + self.jwt_subject_claim = None self.jwt_issuer = None self.jwt_audiences = None @@ -88,6 +91,12 @@ class JWTConfig(Config): # #algorithm: "provided-by-your-issuer" + # Name of the claim containing a unique identifier for the user. + # + # Optional, defaults to `sub`. + # + #subject_claim: "sub" + # The issuer to validate the "iss" claim against. # # Optional, if provided the "iss" claim will be required and diff --git a/synapse/config/key.py b/synapse/config/key.py index 015dbb8a67..035ee2416b 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -16,6 +16,7 @@ import hashlib import logging import os +from typing import Any, Dict import attr import jsonschema @@ -312,7 +313,7 @@ class KeyConfig(Config): ) return keys - def generate_files(self, config, config_dir_path): + def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: if "signing_key" in config: return diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 5252e61a99..ea69b9bd9b 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
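# The new `jwt_subject_claim` option above makes the claim used as the user
# identifier configurable instead of hard-coding "sub". A hedged sketch of
# pulling that claim out of a decoded token with PyJWT; the claim name and
# shared secret are illustrative, and the real login path also checks
# issuer/audience where configured.
import jwt  # PyJWT

def user_from_token(token: str, secret: str, algorithm: str,
                    subject_claim: str = "sub") -> str:
    claims = jwt.decode(token, secret, algorithms=[algorithm])
    user = claims.get(subject_claim)
    if not isinstance(user, str):
        raise ValueError("JWT is missing the %r claim" % subject_claim)
    return user

secret = "shared-secret"
token = jwt.encode({"preferred_username": "alice"}, secret, algorithm="HS256")
print(user_from_token(token, secret, "HS256", subject_claim="preferred_username"))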
@@ -18,7 +19,7 @@ import os import sys import threading from string import Template -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Optional import yaml from zope.interface import implementer @@ -40,6 +41,7 @@ from synapse.util.versionstring import get_version_string from ._base import Config, ConfigError if TYPE_CHECKING: + from synapse.config.homeserver import HomeServerConfig from synapse.server import HomeServer DEFAULT_LOG_CONFIG = Template( @@ -141,13 +143,13 @@ removed in Synapse 1.3.0. You should instead set up a separate log configuration class LoggingConfig(Config): section = "logging" - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: if config.get("log_file"): raise ConfigError(LOG_FILE_ERROR) self.log_config = self.abspath(config.get("log_config")) self.no_redirect_stdio = config.get("no_redirect_stdio", False) - def generate_config_section(self, config_dir_path, server_name, **kwargs): + def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: log_config = os.path.join(config_dir_path, server_name + ".log.config") return ( """\ @@ -161,14 +163,14 @@ class LoggingConfig(Config): % locals() ) - def read_arguments(self, args): + def read_arguments(self, args: argparse.Namespace) -> None: if args.no_redirect_stdio is not None: self.no_redirect_stdio = args.no_redirect_stdio if args.log_file is not None: raise ConfigError(LOG_FILE_ERROR) @staticmethod - def add_arguments(parser): + def add_arguments(parser: argparse.ArgumentParser) -> None: logging_group = parser.add_argument_group("logging") logging_group.add_argument( "-n", @@ -185,7 +187,7 @@ class LoggingConfig(Config): help=argparse.SUPPRESS, ) - def generate_files(self, config, config_dir_path): + def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None: log_config = config.get("log_config") if log_config and not os.path.exists(log_config): log_file = self.abspath("homeserver.log") @@ -197,7 +199,9 @@ class LoggingConfig(Config): log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) -def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None: +def _setup_stdlib_logging( + config: "HomeServerConfig", log_config_path: Optional[str], logBeginner: LogBeginner +) -> None: """ Set up Python standard library logging. """ @@ -230,7 +234,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> log_metadata_filter = MetadataFilter({"server_name": config.server.server_name}) old_factory = logging.getLogRecordFactory() - def factory(*args, **kwargs): + def factory(*args: Any, **kwargs: Any) -> logging.LogRecord: record = old_factory(*args, **kwargs) log_context_filter.filter(record) log_metadata_filter.filter(record) @@ -297,7 +301,7 @@ def _load_logging_config(log_config_path: str) -> None: logging.config.dictConfig(log_config) -def _reload_logging_config(log_config_path): +def _reload_logging_config(log_config_path: Optional[str]) -> None: """ Reload the log configuration from the file and apply it. 
""" @@ -311,8 +315,8 @@ def _reload_logging_config(log_config_path): def setup_logging( hs: "HomeServer", - config, - use_worker_options=False, + config: "HomeServerConfig", + use_worker_options: bool = False, logBeginner: LogBeginner = globalLogBeginner, ) -> None: """ diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 42f113cd24..79c400fe30 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -14,7 +14,7 @@ # limitations under the License. from collections import Counter -from typing import Collection, Iterable, List, Mapping, Optional, Tuple, Type +from typing import Any, Collection, Iterable, List, Mapping, Optional, Tuple, Type import attr @@ -36,7 +36,7 @@ LEGACY_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingPr class OIDCConfig(Config): section = "oidc" - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: self.oidc_providers = tuple(_parse_oidc_provider_configs(config)) if not self.oidc_providers: return @@ -66,7 +66,7 @@ class OIDCConfig(Config): # OIDC is enabled if we have a provider return bool(self.oidc_providers) - def generate_config_section(self, config_dir_path, server_name, **kwargs): + def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: return """\ # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration # and login. @@ -495,89 +495,89 @@ def _parse_oidc_config_dict( ) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class OidcProviderClientSecretJwtKey: # a pem-encoded signing key - key = attr.ib(type=str) + key: str # properties to include in the JWT header - jwt_header = attr.ib(type=Mapping[str, str]) + jwt_header: Mapping[str, str] # properties to include in the JWT payload. - jwt_payload = attr.ib(type=Mapping[str, str]) + jwt_payload: Mapping[str, str] -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class OidcProviderConfig: # a unique identifier for this identity provider. Used in the 'user_external_ids' # table, as well as the query/path parameter used in the login protocol. - idp_id = attr.ib(type=str) + idp_id: str # user-facing name for this identity provider. - idp_name = attr.ib(type=str) + idp_name: str # Optional MXC URI for icon for this IdP. - idp_icon = attr.ib(type=Optional[str]) + idp_icon: Optional[str] # Optional brand identifier for this IdP. - idp_brand = attr.ib(type=Optional[str]) + idp_brand: Optional[str] # whether the OIDC discovery mechanism is used to discover endpoints - discover = attr.ib(type=bool) + discover: bool # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to # discover the provider's endpoints. - issuer = attr.ib(type=str) + issuer: str # oauth2 client id to use - client_id = attr.ib(type=str) + client_id: str # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate # a secret. - client_secret = attr.ib(type=Optional[str]) + client_secret: Optional[str] # key to use to construct a JWT to use as a client secret. May be `None` if # `client_secret` is set. - client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey]) + client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] # auth method to use when exchanging the token. # Valid values are 'client_secret_basic', 'client_secret_post' and # 'none'. 
- client_auth_method = attr.ib(type=str) + client_auth_method: str # list of scopes to request - scopes = attr.ib(type=Collection[str]) + scopes: Collection[str] # the oauth2 authorization endpoint. Required if discovery is disabled. - authorization_endpoint = attr.ib(type=Optional[str]) + authorization_endpoint: Optional[str] # the oauth2 token endpoint. Required if discovery is disabled. - token_endpoint = attr.ib(type=Optional[str]) + token_endpoint: Optional[str] # the OIDC userinfo endpoint. Required if discovery is disabled and the # "openid" scope is not requested. - userinfo_endpoint = attr.ib(type=Optional[str]) + userinfo_endpoint: Optional[str] # URI where to fetch the JWKS. Required if discovery is disabled and the # "openid" scope is used. - jwks_uri = attr.ib(type=Optional[str]) + jwks_uri: Optional[str] # Whether to skip metadata verification - skip_verification = attr.ib(type=bool) + skip_verification: bool # Whether to fetch the user profile from the userinfo endpoint. Valid # values are: "auto" or "userinfo_endpoint". - user_profile_method = attr.ib(type=str) + user_profile_method: str # whether to allow a user logging in via OIDC to match a pre-existing account # instead of failing - allow_existing_users = attr.ib(type=bool) + allow_existing_users: bool # the class of the user mapping provider - user_mapping_provider_class = attr.ib(type=Type) + user_mapping_provider_class: Type # the config of the user mapping provider - user_mapping_provider_config = attr.ib() + user_mapping_provider_config: Any # required attributes to require in userinfo to allow login/registration - attribute_requirements = attr.ib(type=List[SsoAttributeRequirement]) + attribute_requirements: List[SsoAttributeRequirement] diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 5379e80715..7a059c6dec 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import argparse +from typing import Optional from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError @@ -39,9 +42,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) - self.trusted_third_party_id_servers = config.get( - "trusted_third_party_id_servers", ["matrix.org", "vector.im"] - ) + account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") @@ -114,26 +115,74 @@ class RegistrationConfig(Config): session_lifetime = self.parse_duration(session_lifetime) self.session_lifetime = session_lifetime - # The `access_token_lifetime` applies for tokens that can be renewed - # using a refresh token, as per MSC2918. If it is `None`, the refresh - # token mechanism is disabled. - # - # Since it is incompatible with the `session_lifetime` mechanism, it is set to - # `None` by default if a `session_lifetime` is set. 
- access_token_lifetime = config.get( - "access_token_lifetime", "5m" if session_lifetime is None else None + # The `refreshable_access_token_lifetime` applies for tokens that can be renewed + # using a refresh token, as per MSC2918. + # If it is `None`, the refresh token mechanism is disabled. + refreshable_access_token_lifetime = config.get( + "refreshable_access_token_lifetime", + "5m", + ) + if refreshable_access_token_lifetime is not None: + refreshable_access_token_lifetime = self.parse_duration( + refreshable_access_token_lifetime + ) + self.refreshable_access_token_lifetime: Optional[ + int + ] = refreshable_access_token_lifetime + + if ( + self.session_lifetime is not None + and "refreshable_access_token_lifetime" in config + ): + if self.session_lifetime < self.refreshable_access_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `refreshable_access_token_lifetime` " + "configuration options have been set, but `refreshable_access_token_lifetime` " + " exceeds `session_lifetime`!" + ) + + # The `nonrefreshable_access_token_lifetime` applies for tokens that can NOT be + # refreshed using a refresh token. + # If it is None, then these tokens last for the entire length of the session, + # which is infinite by default. + # The intention behind this configuration option is to help with requiring + # all clients to use refresh tokens, if the homeserver administrator requires. + nonrefreshable_access_token_lifetime = config.get( + "nonrefreshable_access_token_lifetime", + None, ) - if access_token_lifetime is not None: - access_token_lifetime = self.parse_duration(access_token_lifetime) - self.access_token_lifetime = access_token_lifetime - - if session_lifetime is not None and access_token_lifetime is not None: - raise ConfigError( - "The refresh token mechanism is incompatible with the " - "`session_lifetime` option. Consider disabling the " - "`session_lifetime` option or disabling the refresh token " - "mechanism by removing the `access_token_lifetime` option." + if nonrefreshable_access_token_lifetime is not None: + nonrefreshable_access_token_lifetime = self.parse_duration( + nonrefreshable_access_token_lifetime ) + self.nonrefreshable_access_token_lifetime = nonrefreshable_access_token_lifetime + + if ( + self.session_lifetime is not None + and self.nonrefreshable_access_token_lifetime is not None + ): + if self.session_lifetime < self.nonrefreshable_access_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `nonrefreshable_access_token_lifetime` " + "configuration options have been set, but `nonrefreshable_access_token_lifetime` " + " exceeds `session_lifetime`!" + ) + + refresh_token_lifetime = config.get("refresh_token_lifetime") + if refresh_token_lifetime is not None: + refresh_token_lifetime = self.parse_duration(refresh_token_lifetime) + self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime + + if ( + self.session_lifetime is not None + and self.refresh_token_lifetime is not None + ): + if self.session_lifetime < self.refresh_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `refresh_token_lifetime` " + "configuration options have been set, but `refresh_token_lifetime` " + " exceeds `session_lifetime`!" 
+ ) # The fallback template used for authenticating using a registration token self.registration_token_template = self.read_template("registration_token.html") @@ -171,6 +220,44 @@ class RegistrationConfig(Config): # #session_lifetime: 24h + # Time that an access token remains valid for, if the session is + # using refresh tokens. + # For more information about refresh tokens, please see the manual. + # Note that this only applies to clients which advertise support for + # refresh tokens. + # + # Note also that this is calculated at login time and refresh time: + # changes are not applied to existing sessions until they are refreshed. + # + # By default, this is 5 minutes. + # + #refreshable_access_token_lifetime: 5m + + # Time that a refresh token remains valid for (provided that it is not + # exchanged for another one first). + # This option can be used to automatically log-out inactive sessions. + # Please see the manual for more information. + # + # Note also that this is calculated at login time and refresh time: + # changes are not applied to existing sessions until they are refreshed. + # + # By default, this is infinite. + # + #refresh_token_lifetime: 24h + + # Time that an access token remains valid for, if the session is NOT + # using refresh tokens. + # Please note that not all clients support refresh tokens, so setting + # this to a short value may be inconvenient for some users who will + # then be logged out frequently. + # + # Note also that this is calculated at login time: changes are not applied + # retrospectively to existing sessions for users that have already logged in. + # + # By default, this is infinite. + # + #nonrefreshable_access_token_lifetime: 24h + # The user must provide all of the below types of 3PID when registering. 
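# The three lifetime checks above enforce one invariant: no per-token
# lifetime may exceed the overall `session_lifetime`. A compact sketch of
# that validation, with a toy duration parser standing in for Synapse's
# Config.parse_duration.
from typing import Optional

_UNITS = {"s": 1_000, "m": 60_000, "h": 3_600_000, "d": 86_400_000}

def parse_duration(value: str) -> int:
    # Toy parser: "<number><unit>" -> milliseconds.
    return int(value[:-1]) * _UNITS[value[-1]]

def check_lifetime(name: str, lifetime: Optional[int],
                   session_lifetime: Optional[int]) -> None:
    if session_lifetime is not None and lifetime is not None:
        if session_lifetime < lifetime:
            raise ValueError("`%s` exceeds `session_lifetime`!" % name)

session = parse_duration("24h")
check_lifetime("refreshable_access_token_lifetime", parse_duration("5m"), session)
try:
    check_lifetime("refresh_token_lifetime", parse_duration("30d"), session)
except ValueError as e:
    print(e)  # `refresh_token_lifetime` exceeds `session_lifetime`!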
# #registrations_require_3pid: @@ -364,7 +451,7 @@ class RegistrationConfig(Config): ) @staticmethod - def add_arguments(parser): + def add_arguments(parser: argparse.ArgumentParser) -> None: reg_group = parser.add_argument_group("registration") reg_group.add_argument( "--enable-registration", @@ -373,6 +460,6 @@ class RegistrationConfig(Config): help="Enable registration for new users.", ) - def read_arguments(self, args): + def read_arguments(self, args: argparse.Namespace) -> None: if args.enable_registration is not None: self.enable_registration = strtobool(str(args.enable_registration)) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 69906a98d4..b129b9dd68 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -15,11 +15,12 @@ import logging import os from collections import namedtuple -from typing import Dict, List +from typing import Dict, List, Tuple from urllib.request import getproxies_environment # type: ignore from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements +from synapse.types import JsonDict from synapse.util.module_loader import load_module from ._base import Config, ConfigError @@ -57,7 +58,9 @@ MediaStorageProviderConfig = namedtuple( ) -def parse_thumbnail_requirements(thumbnail_sizes): +def parse_thumbnail_requirements( + thumbnail_sizes: List[JsonDict], +) -> Dict[str, Tuple[ThumbnailRequirement, ...]]: """Takes a list of dictionaries with "width", "height", and "method" keys and creates a map from image media types to the thumbnail size, thumbnailing method, and thumbnail media type to precalculate @@ -69,7 +72,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): Dictionary mapping from media type string to list of ThumbnailRequirement tuples. """ - requirements: Dict[str, List] = {} + requirements: Dict[str, List[ThumbnailRequirement]] = {} for size in thumbnail_sizes: width = size["width"] height = size["height"] diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 56981cac79..57316c59b6 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -1,4 +1,5 @@ # Copyright 2018 New Vector Ltd +# Copyright 2021 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +from synapse.types import JsonDict from synapse.util import glob_to_regex from ._base import Config, ConfigError @@ -20,7 +24,7 @@ from ._base import Config, ConfigError class RoomDirectoryConfig(Config): section = "roomdirectory" - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: self.enable_room_list_search = config.get("enable_room_list_search", True) alias_creation_rules = config.get("alias_creation_rules") @@ -47,7 +51,7 @@ class RoomDirectoryConfig(Config): _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"}) ] - def generate_config_section(self, config_dir_path, server_name, **kwargs): + def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: return """ # Uncomment to disable searching the public room list. 
When disabled # blocks searching local and remote room lists for local and remote @@ -113,16 +117,16 @@ class RoomDirectoryConfig(Config): # action: allow """ - def is_alias_creation_allowed(self, user_id, room_id, alias): + def is_alias_creation_allowed(self, user_id: str, room_id: str, alias: str) -> bool: """Checks if the given user is allowed to create the given alias Args: - user_id (str) - room_id (str) - alias (str) + user_id: The user to check. + room_id: The room ID for the alias. + alias: The alias being created. Returns: - boolean: True if user is allowed to create the alias + True if user is allowed to create the alias """ for rule in self._alias_creation_rules: if rule.matches(user_id, room_id, [alias]): @@ -130,16 +134,18 @@ class RoomDirectoryConfig(Config): return False - def is_publishing_room_allowed(self, user_id, room_id, aliases): + def is_publishing_room_allowed( + self, user_id: str, room_id: str, aliases: List[str] + ) -> bool: """Checks if the given user is allowed to publish the room Args: - user_id (str) - room_id (str) - aliases (list[str]): any local aliases associated with the room + user_id: The user ID publishing the room. + room_id: The room being published. + aliases: any local aliases associated with the room Returns: - boolean: True if user can publish room + True if user can publish room """ for rule in self._room_list_publication_rules: if rule.matches(user_id, room_id, aliases): @@ -153,11 +159,11 @@ class _RoomDirectoryRule: creating an alias or publishing a room. """ - def __init__(self, option_name, rule): + def __init__(self, option_name: str, rule: JsonDict): """ Args: - option_name (str): Name of the config option this rule belongs to - rule (dict): The rule as specified in the config + option_name: Name of the config option this rule belongs to + rule: The rule as specified in the config """ action = rule["action"] @@ -181,18 +187,18 @@ class _RoomDirectoryRule: except Exception as e: raise ConfigError("Failed to parse glob into regex") from e - def matches(self, user_id, room_id, aliases): + def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool: """Tests if this rule matches the given user_id, room_id and aliases. Args: - user_id (str) - room_id (str) - aliases (list[str]): The associated aliases to the room. Will be a - single element for testing alias creation, and can be empty for - testing room publishing. + user_id: The user ID to check. + room_id: The room ID to check. + aliases: The associated aliases to the room. Will be a single element + for testing alias creation, and can be empty for testing room + publishing. Returns: - boolean + True if the rule matches. """ # Note: The regexes are anchored at both ends diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index ba2b0905ff..ec9d9f65e7 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -1,5 +1,5 @@ # Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,10 +14,11 @@ # limitations under the License. 
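# The alias-creation and room-publication checks above are first-match glob
# rules over (user_id, room_id, aliases). A minimal sketch using the stdlib's
# fnmatch in place of Synapse's glob_to_regex; room_id matching and other
# details are omitted for brevity.
from fnmatch import fnmatchcase
from typing import List

class Rule:
    def __init__(self, rule: dict):
        self.action = rule["action"]            # "allow" or "deny"
        self.user_id = rule.get("user_id", "*")
        self.alias = rule.get("alias", "*")

    def matches(self, user_id: str, aliases: List[str]) -> bool:
        if not fnmatchcase(user_id, self.user_id):
            return False
        return any(fnmatchcase(a, self.alias) for a in aliases)

rules = [Rule({"action": "deny", "user_id": "@spam:*"}), Rule({"action": "allow"})]

def is_alias_creation_allowed(user_id: str, alias: str) -> bool:
    for rule in rules:
        if rule.matches(user_id, [alias]):
            return rule.action == "allow"
    return False

print(is_alias_creation_allowed("@spam:evil.example", "#room:example.com"))   # False
print(is_alias_creation_allowed("@alice:example.com", "#room:example.com"))   # True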
import logging -from typing import Any, List +from typing import Any, List, Set from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements +from synapse.types import JsonDict from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError @@ -33,7 +34,7 @@ LEGACY_USER_MAPPING_PROVIDER = ( ) -def _dict_merge(merge_dict, into_dict): +def _dict_merge(merge_dict: dict, into_dict: dict) -> None: """Do a deep merge of two dicts Recursively merges `merge_dict` into `into_dict`: @@ -43,8 +44,8 @@ def _dict_merge(merge_dict, into_dict): the value from `merge_dict`. Args: - merge_dict (dict): dict to merge - into_dict (dict): target dict + merge_dict: dict to merge + into_dict: target dict to be modified """ for k, v in merge_dict.items(): if k not in into_dict: @@ -64,7 +65,7 @@ def _dict_merge(merge_dict, into_dict): class SAML2Config(Config): section = "saml2" - def read_config(self, config, **kwargs): + def read_config(self, config, **kwargs) -> None: self.saml2_enabled = False saml2_config = config.get("saml2_config") @@ -183,8 +184,8 @@ class SAML2Config(Config): ) def _default_saml_config_dict( - self, required_attributes: set, optional_attributes: set - ): + self, required_attributes: Set[str], optional_attributes: Set[str] + ) -> JsonDict: """Generate a configuration dictionary with required and optional attributes that will be needed to process new user registration @@ -195,7 +196,7 @@ class SAML2Config(Config): additional information to Synapse user accounts, but are not required Returns: - dict: A SAML configuration dictionary + A SAML configuration dictionary """ import saml2 @@ -222,7 +223,7 @@ class SAML2Config(Config): }, } - def generate_config_section(self, config_dir_path, server_name, **kwargs): + def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str: return """\ ## Single sign-on integration ## diff --git a/synapse/config/server.py b/synapse/config/server.py index 7bc0030a9e..ba5b954263 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import itertools import logging import os.path @@ -27,6 +28,7 @@ from netaddr import AddrFormatError, IPNetwork, IPSet from twisted.conch.ssh.keys import Key from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.types import JsonDict from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_server_name @@ -421,7 +423,7 @@ class ServerConfig(Config): # before redacting them. redaction_retention_period = config.get("redaction_retention_period", "7d") if redaction_retention_period is not None: - self.redaction_retention_period = self.parse_duration( + self.redaction_retention_period: Optional[int] = self.parse_duration( redaction_retention_period ) else: @@ -430,7 +432,7 @@ class ServerConfig(Config): # How long to keep entries in the `users_ips` table. 
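# `_dict_merge` in the saml2.py hunk above recursively folds one mapping into
# another in place. A standalone copy of the logic plus a usage example;
# behaviour matches the docstring (nested dicts merge, scalar conflicts take
# the merged-in value).
def _dict_merge(merge_dict: dict, into_dict: dict) -> None:
    for k, v in merge_dict.items():
        if k not in into_dict:
            into_dict[k] = v
        elif isinstance(v, dict) and isinstance(into_dict[k], dict):
            _dict_merge(v, into_dict[k])
        else:
            # Non-dict (or mismatched) values: the merged-in value wins.
            into_dict[k] = v

target = {"sp": {"allow_unsolicited": False}, "debug": 0}
_dict_merge({"sp": {"want_response_signed": True}, "debug": 1}, target)
print(target)
# {'sp': {'allow_unsolicited': False, 'want_response_signed': True}, 'debug': 1}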
        user_ips_max_age = config.get("user_ips_max_age", "28d")
         if user_ips_max_age is not None:
-            self.user_ips_max_age = self.parse_duration(user_ips_max_age)
+            self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age)
         else:
             self.user_ips_max_age = None
@@ -1223,7 +1225,7 @@ class ServerConfig(Config):
             % locals()
         )
-    def read_arguments(self, args):
+    def read_arguments(self, args: argparse.Namespace) -> None:
         if args.manhole is not None:
             self.manhole = args.manhole
         if args.daemonize is not None:
@@ -1232,7 +1234,7 @@ class ServerConfig(Config):
             self.print_pidfile = args.print_pidfile
     @staticmethod
-    def add_arguments(parser):
+    def add_arguments(parser: argparse.ArgumentParser) -> None:
         server_group = parser.add_argument_group("server")
         server_group.add_argument(
             "-D",
@@ -1274,14 +1276,16 @@ class ServerConfig(Config):
     )
-def is_threepid_reserved(reserved_threepids, threepid):
+def is_threepid_reserved(
+    reserved_threepids: List[JsonDict], threepid: JsonDict
+) -> bool:
     """Check the threepid against the reserved threepid config
     Args:
-        reserved_threepids([dict]) - list of reserved threepids
-        threepid(dict) - The threepid to test for
+        reserved_threepids: List of reserved threepids
+        threepid: The threepid to test for
     Returns:
-        boolean Is the threepid undertest reserved_user
+        True if the threepid under test is a reserved user
     """
     for tp in reserved_threepids:
@@ -1290,7 +1294,9 @@ class ServerConfig(Config):
     return False
-def read_gc_thresholds(thresholds):
+def read_gc_thresholds(
+    thresholds: Optional[List[Any]],
+) -> Optional[Tuple[int, int, int]]:
     """Reads the three integer thresholds for garbage collection. Ensures that
     the thresholds are integers if thresholds are supplied.
     """
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 60aacb13ea..e4a4243261 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,13 +29,13 @@ https://matrix-org.github.io/synapse/latest/templates.html
 ---------------------------------------------------------------------------------------"""
-@attr.s(frozen=True)
+@attr.s(frozen=True, auto_attribs=True)
 class SsoAttributeRequirement:
     """Object describing a single requirement for SSO attributes."""
-    attribute = attr.ib(type=str)
+    attribute: str
     # If a value is not given, then the attribute must simply exist.
-    value = attr.ib(type=Optional[str])
+    value: Optional[str]
     JSON_SCHEMA = {
         "type": "object",
@@ -49,7 +49,7 @@ class SSOConfig(Config):
     section = "sso"
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         sso_config: Dict[str, Any] = config.get("sso") or {}
         # The sso-specific template_dir
@@ -106,7 +106,7 @@ class SSOConfig(Config):
         )
         self.sso_client_whitelist.append(login_fallback_url)
-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs) -> str:
         return """\
         # Additional settings to use with single-sign on systems such as OpenID Connect,
         # SAML2 and CAS.
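# `read_gc_thresholds` above normalises an optional three-element list into
# the tuple that Python's garbage collector expects. A sketch of the helper
# and of how the parsed result would be applied; the config values are
# illustrative.
import gc
from typing import Any, List, Optional, Tuple

def read_gc_thresholds(
    thresholds: Optional[List[Any]],
) -> Optional[Tuple[int, int, int]]:
    if thresholds is None:
        return None
    try:
        assert len(thresholds) == 3
        return int(thresholds[0]), int(thresholds[1]), int(thresholds[2])
    except Exception:
        raise ValueError("Value of `gc_threshold` must be a list of three integers")

parsed = read_gc_thresholds([700, 10, 10])
if parsed is not None:
    gc.set_threshold(*parsed)
print(gc.get_threshold())  # (700, 10, 10)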
diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 6227434bac..4ca111618f 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -14,7 +14,6 @@ import logging import os -from datetime import datetime from typing import List, Optional, Pattern from OpenSSL import SSL, crypto @@ -133,55 +132,6 @@ class TlsConfig(Config): self.tls_certificate: Optional[crypto.X509] = None self.tls_private_key: Optional[crypto.PKey] = None - def is_disk_cert_valid(self, allow_self_signed=True): - """ - Is the certificate we have on disk valid, and if so, for how long? - - Args: - allow_self_signed (bool): Should we allow the certificate we - read to be self signed? - - Returns: - int: Days remaining of certificate validity. - None: No certificate exists. - """ - if not os.path.exists(self.tls_certificate_file): - return None - - try: - with open(self.tls_certificate_file, "rb") as f: - cert_pem = f.read() - except Exception as e: - raise ConfigError( - "Failed to read existing certificate file %s: %s" - % (self.tls_certificate_file, e) - ) - - try: - tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) - except Exception as e: - raise ConfigError( - "Failed to parse existing certificate file %s: %s" - % (self.tls_certificate_file, e) - ) - - if not allow_self_signed: - if tls_certificate.get_subject() == tls_certificate.get_issuer(): - raise ValueError( - "TLS Certificate is self signed, and this is not permitted" - ) - - # YYYYMMDDhhmmssZ -- in UTC - expiry_data = tls_certificate.get_notAfter() - if expiry_data is None: - raise ValueError( - "TLS Certificate has no expiry date, and this is not permitted" - ) - expires_on = datetime.strptime(expiry_data.decode("ascii"), "%Y%m%d%H%M%SZ") - now = datetime.utcnow() - days_remaining = (expires_on - now).days - return days_remaining - def read_certificate_from_disk(self): """ Read the certificates and private key from disk. @@ -263,8 +213,8 @@ class TlsConfig(Config): # #federation_certificate_verification_whitelist: # - lon.example.com - # - *.domain.com - # - *.onion + # - "*.domain.com" + # - "*.onion" # List of custom certificate authorities for federation traffic. # @@ -295,7 +245,7 @@ class TlsConfig(Config): cert_path = self.tls_certificate_file logger.info("Loading TLS certificate from %s", cert_path) cert_pem = self.read_file(cert_path, "tls_certificate_path") - cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem.encode()) return cert diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index 2552f688d0..6d6678c7e4 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -53,8 +53,8 @@ class UserDirectoryConfig(Config): # indexes were (re)built was before Synapse 1.44, you'll have to # rebuild the indexes in order to search through all known users. # These indexes are built the first time Synapse starts; admins can - # manually trigger a rebuild following the instructions at - # https://matrix-org.github.io/synapse/latest/user_directory.html + # manually trigger a rebuild via API following the instructions at + # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/background_updates.html#run # # Uncomment to return search results containing all known users, even if that # user does not share a room with the requester. 
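# The tls.py change above hands bytes to pyOpenSSL: `read_file` returns the
# PEM as a str, while `load_certificate` expects bytes. A small sketch of the
# pattern; the certificate path is an illustrative placeholder.
from OpenSSL import crypto

def load_cert(path: str) -> crypto.X509:
    with open(path, encoding="ascii") as f:
        cert_pem = f.read()  # a str, as Synapse's read_file returns
    # Encode back to bytes before handing the PEM to pyOpenSSL.
    return crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem.encode())

cert = load_cert("/path/to/tls.crt")
print(cert.get_subject())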
diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 4507992031..576f519188 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse from typing import List, Union import attr @@ -343,7 +345,7 @@ class WorkerConfig(Config): #worker_replication_secret: "" """ - def read_arguments(self, args): + def read_arguments(self, args: argparse.Namespace) -> None: # We support a bunch of command line arguments that override options in # the config. A lot of these options have a worker_* prefix when running # on workers so we also have to override them when command line options diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index f641ab7ef5..993b04099e 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -1,5 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017, 2018 New Vector Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -120,16 +119,6 @@ class VerifyJsonRequest: key_ids=key_ids, ) - def to_fetch_key_request(self) -> "_FetchKeyRequest": - """Create a key fetch request for all keys needed to satisfy the - verification request. - """ - return _FetchKeyRequest( - server_name=self.server_name, - minimum_valid_until_ts=self.minimum_valid_until_ts, - key_ids=self.key_ids, - ) - class KeyLookupError(ValueError): pass @@ -179,8 +168,22 @@ class Keyring: clock=hs.get_clock(), process_batch_callback=self._inner_fetch_key_requests, ) - self.verify_key = get_verify_key(hs.signing_key) - self.hostname = hs.hostname + + self._hostname = hs.hostname + + # build a FetchKeyResult for each of our own keys, to shortcircuit the + # fetcher. + self._local_verify_keys: Dict[str, FetchKeyResult] = {} + for key_id, key in hs.config.key.old_signing_keys.items(): + self._local_verify_keys[key_id] = FetchKeyResult( + verify_key=key, valid_until_ts=key.expired_ts + ) + + vk = get_verify_key(hs.signing_key) + self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult( + verify_key=vk, + valid_until_ts=2 ** 63, # fake future timestamp + ) async def verify_json_for_server( self, @@ -267,22 +270,32 @@ class Keyring: Codes.UNAUTHORIZED, ) - # If we are the originating server don't fetch verify key for self over federation - if verify_request.server_name == self.hostname: - await self._process_json(self.verify_key, verify_request) - return + found_keys: Dict[str, FetchKeyResult] = {} - # Add the keys we need to verify to the queue for retrieval. We queue - # up requests for the same server so we don't end up with many in flight - # requests for the same keys. 
- key_request = verify_request.to_fetch_key_request() - found_keys_by_server = await self._server_queue.add_to_queue( - key_request, key=verify_request.server_name - ) + # If we are the originating server, short-circuit the key-fetch for any keys + # we already have + if verify_request.server_name == self._hostname: + for key_id in verify_request.key_ids: + if key_id in self._local_verify_keys: + found_keys[key_id] = self._local_verify_keys[key_id] + + key_ids_to_find = set(verify_request.key_ids) - found_keys.keys() + if key_ids_to_find: + # Add the keys we need to verify to the queue for retrieval. We queue + # up requests for the same server so we don't end up with many in flight + # requests for the same keys. + key_request = _FetchKeyRequest( + server_name=verify_request.server_name, + minimum_valid_until_ts=verify_request.minimum_valid_until_ts, + key_ids=list(key_ids_to_find), + ) + found_keys_by_server = await self._server_queue.add_to_queue( + key_request, key=verify_request.server_name + ) - # Since we batch up requests the returned set of keys may contain keys - # from other servers, so we pull out only the ones we care about.s - found_keys = found_keys_by_server.get(verify_request.server_name, {}) + # Since we batch up requests the returned set of keys may contain keys + # from other servers, so we pull out only the ones we care about. + found_keys.update(found_keys_by_server.get(verify_request.server_name, {})) # Verify each signature we got valid keys for, raising if we can't # verify any of them. @@ -654,21 +667,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): perspective_name, ) + request: JsonDict = {} + for queue_value in keys_to_fetch: + # there may be multiple requests for each server, so we have to merge + # them intelligently. + request_for_server = { + key_id: { + "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, + } + for key_id in queue_value.key_ids + } + request.setdefault(queue_value.server_name, {}).update(request_for_server) + + logger.debug("Request to notary server %s: %s", perspective_name, request) + try: query_response = await self.client.post_json( destination=perspective_name, path="/_matrix/key/v2/query", - data={ - "server_keys": { - queue_value.server_name: { - key_id: { - "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, - } - for key_id in queue_value.key_ids - } - for queue_value in keys_to_fetch - } - }, + data={"server_keys": request}, ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve upon @@ -676,6 +693,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) + logger.debug( + "Response from notary server %s: %s", perspective_name, query_response + ) + keys: Dict[str, Dict[str, FetchKeyResult]] = {} added_keys: List[Tuple[str, str, FetchKeyResult]] = [] diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 4f409f31e1..eb39e0ae32 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -128,14 +128,12 @@ class EventBuilder: ) format_version = self.room_version.event_format + # The types of auth/prev events changes between event versions. + prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] + auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] if format_version == EventFormatVersions.V1: - # The types of auth/prev events changes between event versions. 
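# The Keyring change in crypto/keyring.py above answers requests for the
# server's own keys without a federation round-trip by pre-building a map of
# local verify keys. A sketch of that short-circuit using the signedjson
# library; FetchKeyResult is reduced to a plain (key, valid_until_ts) tuple.
from typing import Dict, Set, Tuple

from signedjson.key import generate_signing_key, get_verify_key

local_verify_keys: Dict[str, Tuple[object, int]] = {}

signing_key = generate_signing_key("a_version")
vk = get_verify_key(signing_key)
local_verify_keys[f"{vk.alg}:{vk.version}"] = (vk, 2 ** 63)  # fake future timestamp

def split_request(key_ids: Set[str]):
    # Keys we already hold are answered locally; only the rest go to the
    # batched fetcher.
    found = {k: local_verify_keys[k] for k in key_ids if k in local_verify_keys}
    return found, key_ids - found.keys()

found, to_fetch = split_request({f"{vk.alg}:{vk.version}", "ed25519:old_key"})
print(sorted(found), sorted(to_fetch))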
-            auth_events: Union[
-                List[str], List[Tuple[str, Dict[str, str]]]
-            ] = await self._store.add_event_hashes(auth_event_ids)
-            prev_events: Union[
-                List[str], List[Tuple[str, Dict[str, str]]]
-            ] = await self._store.add_event_hashes(prev_event_ids)
+            auth_events = await self._store.add_event_hashes(auth_event_ids)
+            prev_events = await self._store.add_event_hashes(prev_event_ids)
         else:
             auth_events = auth_event_ids
             prev_events = prev_event_ids
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index d7527008c4..f251402ed8 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -322,6 +322,11 @@ class _AsyncEventContextImpl(EventContext):
         attributes by loading from the database.
         """
         if self.state_group is None:
+            # No state group means the event is an outlier. Usually the state_ids dicts are also
+            # pre-set to empty dicts, but they get reset when the context is serialized, so set
+            # them to empty dicts again here.
+            self._current_state_ids = {}
+            self._prev_state_ids = {}
             return
         current_state_ids = await self._storage.state.get_state_ids_for_group(
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 6fa631aa1d..84ef69df67 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -305,6 +306,7 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
 def serialize_event(
     e: Union[JsonDict, EventBase],
     time_now_ms: int,
+    *,
     as_client_event: bool = True,
     event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
     token_id: Optional[str] = None,
@@ -392,15 +394,18 @@ class EventClientSerializer:
         self,
         event: Union[JsonDict, EventBase],
         time_now: int,
+        *,
         bundle_aggregations: bool = True,
         **kwargs: Any,
     ) -> JsonDict:
         """Serializes a single event.
         Args:
-            event
+            event: The event being serialized.
             time_now: The current time in milliseconds
-            bundle_aggregations: Whether to bundle in related events
+            bundle_aggregations: Whether to include the bundled aggregations for this
+                event. Only applies to non-state events. (State events never include
+                bundled aggregations.)
             **kwargs: Arguments to pass to `serialize_event`
         Returns:
@@ -410,76 +415,109 @@ class EventClientSerializer:
         if not isinstance(event, EventBase):
             return event
-        event_id = event.event_id
         serialized_event = serialize_event(event, time_now, **kwargs)
-        # If MSC1849 is enabled then we need to look if there are any relations
-        # we need to bundle in with the event.
-        # Do not bundle relations if the event has been redacted
-        if not event.internal_metadata.is_redacted() and (
-            self._msc1849_enabled and bundle_aggregations
+        # Check if there are any bundled aggregations to include with the event.
+        #
+        # Do not bundle aggregations if any of the following are true:
+        #
+        # * Support is disabled via the configuration or the caller.
+        # * The event is a state event.
+        # * The event has been redacted.
+ if ( + self._msc1849_enabled + and bundle_aggregations + and not event.is_state() + and not event.internal_metadata.is_redacted() ): - annotations = await self.store.get_aggregation_groups_for_event(event_id) - references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" - ) + await self._injected_bundled_aggregations(event, time_now, serialized_event) - if annotations.chunk: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.ANNOTATION] = annotations.to_dict() - - if references.chunk: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REFERENCE] = references.to_dict() - - edit = None - if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) - - if edit: - # If there is an edit replace the content, preserving existing - # relations. - - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relations = event.content.get("m.relates_to") - if relations: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relations.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.REPLACE] = { - "event_id": edit.event_id, - "origin_server_ts": edit.origin_server_ts, - "sender": edit.sender, - } + return serialized_event - # If this event is the start of a thread, include a summary of the replies. - if self._msc3440_enabled: - ( - thread_count, - latest_thread_event, - ) = await self.store.get_thread_summary(event_id) - if latest_thread_event: - r = serialized_event["unsigned"].setdefault("m.relations", {}) - r[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=False - ), - "count": thread_count, - } + async def _injected_bundled_aggregations( + self, event: EventBase, time_now: int, serialized_event: JsonDict + ) -> None: + """Potentially injects bundled aggregations into the unsigned portion of the serialized event. - return serialized_event + Args: + event: The event being serialized. + time_now: The current time in milliseconds + serialized_event: The serialized event which may be modified. + + """ + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return + + event_id = event.event_id + + # The bundled aggregations to include. 
+ aggregations = {} + + annotations = await self.store.get_aggregation_groups_for_event(event_id) + if annotations.chunk: + aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + + references = await self.store.get_relations_for_event( + event_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = await self.store.get_applicable_edit(event_id) + + if edit: + # If there is an edit replace the content, preserving existing + # relations. + + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations + relates_to = event.content.get("m.relates_to") + if relates_to: + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relates_to.copy() + else: + serialized_event["content"].pop("m.relates_to", None) + + aggregations[RelationTypes.REPLACE] = { + "event_id": edit.event_id, + "origin_server_ts": edit.origin_server_ts, + "sender": edit.sender, + } + + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.store.get_thread_summary(event_id) + if latest_thread_event: + aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": await self.serialize_event( + latest_thread_event, time_now, bundle_aggregations=False + ), + "count": thread_count, + } + + # If any bundled aggregations were found, include them. + if aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + aggregations + ) async def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 3b85b135e0..fee1477ab6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -128,7 +128,7 @@ class FederationClient(FederationBase): reset_expiry_on_get=False, ) - def _clear_tried_cache(self): + def _clear_tried_cache(self) -> None: """Clear pdu_destination_tried cache""" now = self._clock.time_msec() @@ -800,7 +800,7 @@ class FederationClient(FederationBase): no servers successfully handle the request. """ - async def send_request(destination) -> SendJoinResult: + async def send_request(destination: str) -> SendJoinResult: response = await self._do_send_join(room_version, destination, pdu) # If an event was returned (and expected to be returned): @@ -1395,11 +1395,28 @@ class FederationClient(FederationBase): async def send_request( destination: str, ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: - res = await self.transport_layer.get_room_hierarchy( - destination=destination, - room_id=room_id, - suggested_only=suggested_only, - ) + try: + res = await self.transport_layer.get_room_hierarchy( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + ) + except HttpResponseException as e: + # If an error is received that is due to an unrecognised endpoint, + # fallback to the unstable endpoint. 
Otherwise consider it a
+                # legitimate error and raise.
+                if not self._is_unknown_endpoint(e):
+                    raise
+
+                logger.debug(
+                    "Couldn't fetch room hierarchy with the v1 API, falling back to the unstable API"
+                )
+
+                res = await self.transport_layer.get_room_hierarchy_unstable(
+                    destination=destination,
+                    room_id=room_id,
+                    suggested_only=suggested_only,
+                )
             room = res.get("room")
             if not isinstance(room, dict):
@@ -1449,6 +1466,10 @@ class FederationClient(FederationBase):
                 if e.code != 502:
                     raise
+                logger.debug(
+                    "Couldn't fetch room hierarchy, falling back to the spaces API"
+                )
+
                 # Fallback to the old federation API and translate the results if
                 # no servers implement the new API.
                 #
@@ -1496,6 +1517,83 @@ class FederationClient(FederationBase):
         self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
         return result
+    async def timestamp_to_event(
+        self, destination: str, room_id: str, timestamp: int, direction: str
+    ) -> "TimestampToEventResponse":
+        """
+        Calls a remote federating server at `destination` asking for their
+        closest event to the given timestamp in the given direction. Also
+        validates the response to always return the expected keys or raises an
+        error.
+
+        Args:
+            destination: Domain name of the remote homeserver
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            A parsed TimestampToEventResponse including the closest event_id
+            and origin_server_ts
+
+        Raises:
+            Various exceptions when the request fails
+            InvalidResponseError when the response does not have the correct
+                keys or wrong types
+        """
+        remote_response = await self.transport_layer.timestamp_to_event(
+            destination, room_id, timestamp, direction
+        )
+
+        if not isinstance(remote_response, dict):
+            raise InvalidResponseError(
+                "Response must be a JSON dictionary but received %r" % remote_response
+            )
+
+        try:
+            return TimestampToEventResponse.from_json_dict(remote_response)
+        except ValueError as e:
+            raise InvalidResponseError(str(e))
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class TimestampToEventResponse:
+    """Typed response dictionary for the federation /timestamp_to_event endpoint"""
+
+    event_id: str
+    origin_server_ts: int
+
+    # the raw data, including the above keys
+    data: JsonDict
+
+    @classmethod
+    def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse":
+        """Parsed response from the federation /timestamp_to_event endpoint
+
+        Args:
+            d: JSON object response to be parsed
+
+        Raises:
+            ValueError if d does not have the correct keys or they are the wrong types
+        """
+
+        event_id = d.get("event_id")
+        if not isinstance(event_id, str):
+            raise ValueError(
+                "Invalid response: 'event_id' must be a str but received %r" % event_id
+            )
+
+        origin_server_ts = d.get("origin_server_ts")
+        if not isinstance(origin_server_ts, int):
+            raise ValueError(
+                "Invalid response: 'origin_server_ts' must be an int but received %r"
+                % origin_server_ts
+            )
+
+        return cls(event_id, origin_server_ts, d)
+
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class FederationSpaceSummaryEventResult:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9a8758e9a6..8e37e76206 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,6 +1,6 @@
 # Copyright
2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd -# Copyright 2019 Matrix.org Federation C.I.C +# Copyright 2019-2021 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -110,6 +110,7 @@ class FederationServer(FederationBase): super().__init__(hs) self.handler = hs.get_federation_handler() + self.storage = hs.get_storage() self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -200,6 +201,48 @@ class FederationServer(FederationBase): return 200, res + async def on_timestamp_to_event_request( + self, origin: str, room_id: str, timestamp: int, direction: str + ) -> Tuple[int, Dict[str, Any]]: + """When we receive a federated `/timestamp_to_event` request, + handle all of the logic for validating and fetching the event. + + Args: + origin: The server we received the event from + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + Tuple indicating the response status code and dictionary response + body including `event_id`. + """ + with (await self._server_linearizer.queue((origin, room_id))): + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + # We only try to fetch data from the local database + event_id = await self.store.get_event_id_for_timestamp( + room_id, timestamp, direction + ) + if event_id: + event = await self.store.get_event( + event_id, allow_none=False, allow_rejected=False + ) + + return 200, { + "event_id": event_id, + "origin_server_ts": event.origin_server_ts, + } + + raise SynapseError( + 404, + "Unable to find event from %s in direction %s" % (timestamp, direction), + errcode=Codes.NOT_FOUND, + ) + async def on_incoming_transaction( self, origin: str, @@ -407,7 +450,7 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - async def process_pdus_for_room(room_id: str): + async def process_pdus_for_room(room_id: str) -> None: with nested_logging_context(room_id): logger.debug("Processing PDUs for %s", room_id) @@ -504,7 +547,7 @@ class FederationServer(FederationBase): async def on_state_ids_request( self, origin: str, room_id: str, event_id: str - ) -> Tuple[int, Dict[str, Any]]: + ) -> Tuple[int, JsonDict]: if not event_id: raise NotImplementedError("Specify an event") @@ -524,7 +567,9 @@ class FederationServer(FederationBase): return 200, resp - async def _on_state_ids_request_compute(self, room_id, event_id): + async def _on_state_ids_request_compute( + self, room_id: str, event_id: str + ) -> JsonDict: state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} @@ -613,8 +658,11 @@ class FederationServer(FederationBase): state = await self.store.get_events(state_ids) time_now = self._clock.time_msec() + event_json = event.get_pdu_json() return { - "org.matrix.msc3083.v2.event": event.get_pdu_json(), + # TODO Remove the unstable prefix when servers have updated. 
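# During the transition, the same PDU is carried under both the unstable
# MSC3083 key and the stable "event" key, so a send-join response looks
# roughly like this (illustrative sketch, field contents elided):
#
#     {
#         "event": {...},
#         "org.matrix.msc3083.v2.event": {...},
#         "state": [...],
#         "auth_chain": [...],
#     }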
+ "org.matrix.msc3083.v2.event": event_json, + "event": event_json, "state": [p.get_pdu_json(time_now) for p in state.values()], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain], } diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 4fead6ca29..523ab1c51e 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,6 +24,7 @@ from typing import Optional, Tuple from synapse.federation.units import Transaction from synapse.logging.utils import log_function +from synapse.storage.databases.main import DataStore from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -31,7 +33,7 @@ logger = logging.getLogger(__name__) class TransactionActions: """Defines persistence actions that relate to handling Transactions.""" - def __init__(self, datastore): + def __init__(self, datastore: DataStore): self.store = datastore @log_function diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 1fbf325fdc..63289a5a33 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -350,7 +351,7 @@ class BaseFederationRow: TypeId = "" # Unique string that ids the type. Must be overridden in sub classes. @staticmethod - def from_data(data): + def from_data(data: JsonDict) -> "BaseFederationRow": """Parse the data from the federation stream into a row. Args: @@ -359,7 +360,7 @@ class BaseFederationRow: """ raise NotImplementedError() - def to_data(self): + def to_data(self) -> JsonDict: """Serialize this row to be sent over the federation stream. Returns: @@ -368,7 +369,7 @@ class BaseFederationRow: """ raise NotImplementedError() - def add_to_buffer(self, buff): + def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: """Add this row to the appropriate field in the buffer ready for this to be sent over federation. 
@@ -391,15 +392,15 @@ class PresenceDestinationsRow( TypeId = "pd" @staticmethod - def from_data(data): + def from_data(data: JsonDict) -> "PresenceDestinationsRow": return PresenceDestinationsRow( state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] ) - def to_data(self): + def to_data(self) -> JsonDict: return {"state": self.state.as_dict(), "dests": self.destinations} - def add_to_buffer(self, buff): + def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: buff.presence_destinations.append((self.state, self.destinations)) @@ -417,13 +418,13 @@ class KeyedEduRow( TypeId = "k" @staticmethod - def from_data(data): + def from_data(data: JsonDict) -> "KeyedEduRow": return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"])) - def to_data(self): + def to_data(self) -> JsonDict: return {"key": self.key, "edu": self.edu.get_internal_dict()} - def add_to_buffer(self, buff): + def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu @@ -433,13 +434,13 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu TypeId = "e" @staticmethod - def from_data(data): + def from_data(data: JsonDict) -> "EduRow": return EduRow(Edu(**data)) - def to_data(self): + def to_data(self) -> JsonDict: return self.edu.get_internal_dict() - def add_to_buffer(self, buff): + def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None: buff.edus.setdefault(self.edu.destination, []).append(self.edu) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index afe35e72b6..391b30fbb5 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -1,5 +1,6 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +15,8 @@ # limitations under the License. import datetime import logging -from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple +from types import TracebackType +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type import attr from prometheus_client import Counter @@ -213,7 +215,7 @@ class PerDestinationQueue: self._pending_edus_keyed[(edu.edu_type, key)] = edu self.attempt_new_transaction() - def send_edu(self, edu) -> None: + def send_edu(self, edu: Edu) -> None: self._pending_edus.append(edu) self.attempt_new_transaction() @@ -701,7 +703,12 @@ class _TransactionQueueManager: return self._pdus, pending_edus - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: if exc_type is not None: # Failed to send transaction, so we bail out. 
return diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 10b5aa5af8..9fc4c31c93 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -21,6 +21,7 @@ from typing import ( Callable, Collection, Dict, + Generator, Iterable, List, Mapping, @@ -149,6 +150,42 @@ class TransportLayerClient: ) @log_function + async def timestamp_to_event( + self, destination: str, room_id: str, timestamp: int, direction: str + ) -> Union[JsonDict, List]: + """ + Calls a remote federating server at `destination` asking for their + closest event to the given timestamp in the given direction. + + Args: + destination: Domain name of the remote homeserver + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + Response dict received from the remote homeserver. + + Raises: + Various exceptions when the request fails + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, + "/org.matrix.msc3030/timestamp_to_event/%s", + room_id, + ) + + args = {"ts": [str(timestamp)], "dir": [direction]} + + remote_response = await self.client.get_json( + destination, path=path, args=args, try_trailing_slash_on_400=True + ) + + return remote_response + + @log_function async def send_transaction( self, transaction: Transaction, @@ -199,11 +236,16 @@ class TransportLayerClient: @log_function async def make_query( - self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False - ): + self, + destination: str, + query_type: str, + args: dict, + retry_on_dns_fail: bool, + ignore_backoff: bool = False, + ) -> JsonDict: path = _create_v1_path("/query/%s", query_type) - content = await self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args=args, @@ -212,8 +254,6 @@ class TransportLayerClient: ignore_backoff=ignore_backoff, ) - return content - @log_function async def make_membership_event( self, @@ -1192,10 +1232,24 @@ class TransportLayerClient: ) async def get_room_hierarchy( - self, - destination: str, - room_id: str, - suggested_only: bool, + self, destination: str, room_id: str, suggested_only: bool + ) -> JsonDict: + """ + Args: + destination: The remote server + room_id: The room ID to ask about. + suggested_only: if True, only suggested rooms will be returned + """ + path = _create_v1_path("/hierarchy/%s", room_id) + + return await self.client.get_json( + destination=destination, + path=path, + args={"suggested_only": "true" if suggested_only else "false"}, + ) + + async def get_room_hierarchy_unstable( + self, destination: str, room_id: str, suggested_only: bool ) -> JsonDict: """ Args: @@ -1267,7 +1321,7 @@ class SendJoinResponse: @ijson.coroutine -def _event_parser(event_dict: JsonDict): +def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]: """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs to add them to a given dictionary. 
""" @@ -1278,7 +1332,9 @@ def _event_parser(event_dict: JsonDict): @ijson.coroutine -def _event_list_parser(room_version: RoomVersion, events: List[EventBase]): +def _event_list_parser( + room_version: RoomVersion, events: List[EventBase] +) -> Generator[None, JsonDict, None]: """Helper function for use with `ijson.items_coro` to parse an array of events and add them to the given list. """ @@ -1317,15 +1373,26 @@ class SendJoinParser(ByteParser[SendJoinResponse]): prefix + "auth_chain.item", use_float=True, ) - self._coro_event = ijson.kvitems_coro( + # TODO Remove the unstable prefix when servers have updated. + # + # By re-using the same event dictionary this will cause the parsing of + # org.matrix.msc3083.v2.event and event to stomp over each other. + # Generally this should be fine. + self._coro_unstable_event = ijson.kvitems_coro( _event_parser(self._response.event_dict), prefix + "org.matrix.msc3083.v2.event", use_float=True, ) + self._coro_event = ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "event", + use_float=True, + ) def write(self, data: bytes) -> int: self._coro_state.send(data) self._coro_auth.send(data) + self._coro_unstable_event.send(data) self._coro_event.send(data) return len(data) diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index c32539bf5a..77b936361a 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -22,7 +22,10 @@ from synapse.federation.transport.server._base import ( Authenticator, BaseFederationServlet, ) -from synapse.federation.transport.server.federation import FEDERATION_SERVLET_CLASSES +from synapse.federation.transport.server.federation import ( + FEDERATION_SERVLET_CLASSES, + FederationTimestampLookupServlet, +) from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES from synapse.federation.transport.server.groups_server import ( GROUP_SERVER_SERVLET_CLASSES, @@ -299,7 +302,7 @@ def register_servlets( authenticator: Authenticator, ratelimiter: FederationRateLimiter, servlet_groups: Optional[Iterable[str]] = None, -): +) -> None: """Initialize and register servlet classes. Will by default register all servlets. 
For custom behaviour, pass in @@ -324,6 +327,13 @@ def register_servlets( ) for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]: + # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled + if ( + servletclass == FederationTimestampLookupServlet + and not hs.config.experimental.msc3030_enabled + ): + continue + servletclass( hs=hs, authenticator=authenticator, diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index cef65929c5..dc39e3537b 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -15,10 +15,13 @@ import functools import logging import re +from typing import Any, Awaitable, Callable, Optional, Tuple, cast from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.urls import FEDERATION_V1_PREFIX +from synapse.http.server import HttpServer, ServletCallback from synapse.http.servlet import parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.logging import opentracing from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -29,6 +32,7 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import parse_and_validate_server_name @@ -59,9 +63,11 @@ class Authenticator: self.replication_client = hs.get_tcp_replication() # A method just so we can pass 'self' as the authenticator to the Servlets - async def authenticate_request(self, request, content): + async def authenticate_request( + self, request: SynapseRequest, content: Optional[JsonDict] + ) -> str: now = self._clock.time_msec() - json_request = { + json_request: JsonDict = { "method": request.method.decode("ascii"), "uri": request.uri.decode("ascii"), "destination": self.server_name, @@ -114,7 +120,7 @@ class Authenticator: return origin - async def _reset_retry_timings(self, origin): + async def _reset_retry_timings(self, origin: str) -> None: try: logger.info("Marking origin %r as up", origin) await self.store.set_destination_retry_timings(origin, None, 0, 0) @@ -133,14 +139,14 @@ class Authenticator: logger.exception("Error resetting retry timings on %s", origin) -def _parse_auth_header(header_bytes): +def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]: """Parse an X-Matrix auth header Args: - header_bytes (bytes): header value + header_bytes: header value Returns: - Tuple[str, str, str]: origin, key id, signature. + origin, key id, signature. 
Raises: AuthenticationError if the header could not be parsed @@ -148,9 +154,9 @@ def _parse_auth_header(header_bytes): try: header_str = header_bytes.decode("utf-8") params = header_str.split(" ")[1].split(",") - param_dict = dict(kv.split("=") for kv in params) + param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)} - def strip_quotes(value): + def strip_quotes(value: str) -> str: if value.startswith('"'): return value[1:-1] else: @@ -233,23 +239,25 @@ class BaseFederationServlet: self.ratelimiter = ratelimiter self.server_name = server_name - def _wrap(self, func): + def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback: authenticator = self.authenticator ratelimiter = self.ratelimiter @functools.wraps(func) - async def new_func(request, *args, **kwargs): + async def new_func( + request: SynapseRequest, *args: Any, **kwargs: str + ) -> Optional[Tuple[int, Any]]: """A callback which can be passed to HttpServer.RegisterPaths Args: - request (twisted.web.http.Request): + request: *args: unused? - **kwargs (dict[unicode, unicode]): the dict mapping keys to path - components as specified in the path match regexp. + **kwargs: the dict mapping keys to path components as specified + in the path match regexp. Returns: - Tuple[int, object]|None: (response code, response object) as returned by - the callback method. None if the request has already been handled. + (response code, response object) as returned by the callback method. + None if the request has already been handled. """ content = None if request.method in [b"PUT", b"POST"]: @@ -257,7 +265,9 @@ class BaseFederationServlet: content = parse_json_object_from_request(request) try: - origin = await authenticator.authenticate_request(request, content) + origin: Optional[str] = await authenticator.authenticate_request( + request, content + ) except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: @@ -301,7 +311,7 @@ class BaseFederationServlet: "client disconnected before we started processing " "request" ) - return -1, None + return None response = await func( origin, content, request.args, *args, **kwargs ) @@ -312,9 +322,9 @@ class BaseFederationServlet: return response - return new_func + return cast(ServletCallback, new_func) - def register(self, server): + def register(self, server: HttpServer) -> None: pattern = re.compile("^" + self.PREFIX + self.PATH + "$") for method in ("GET", "PUT", "POST"): diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 2fdf6cc99e..77bfd88ad0 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -174,6 +174,46 @@ class FederationBackfillServlet(BaseFederationServerServlet): return await self.handler.on_backfill_request(origin, room_id, versions, limit) +class FederationTimestampLookupServlet(BaseFederationServerServlet): + """ + API endpoint to fetch the `event_id` of the closest event to the given + timestamp (`ts` query parameter) in the given direction (`dir` query + parameter). + + Useful for other homeservers when they're unable to find an event locally. + + `ts` is a timestamp in milliseconds where we will find the closest event in + the given direction. + + `dir` can be `f` or `b` to indicate forwards and backwards in time from the + given timestamp. + + GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/<roomID>?ts=<timestamp>&dir=<direction> + { + "event_id": ... 
+ } + """ + + PATH = "/timestamp_to_event/(?P<room_id>[^/]*)/?" + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + timestamp = parse_integer_from_args(query, "ts", required=True) + direction = parse_string_from_args( + query, "dir", default="f", allowed_values=["f", "b"], required=True + ) + + return await self.handler.on_timestamp_to_event_request( + origin, room_id, timestamp, direction + ) + + class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P<query_type>[^/]*)" @@ -611,7 +651,6 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): class FederationRoomHierarchyServlet(BaseFederationServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" PATH = "/hierarchy/(?P<room_id>[^/]*)" def __init__( @@ -637,6 +676,10 @@ class FederationRoomHierarchyServlet(BaseFederationServlet): ) +class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet): + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" + + class RoomComplexityServlet(BaseFederationServlet): """ Indicates to other servers how complex (and therefore likely @@ -680,6 +723,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationStateV1Servlet, FederationStateIdsServlet, FederationBackfillServlet, + FederationTimestampLookupServlet, FederationQueryServlet, FederationMakeJoinServlet, FederationMakeLeaveServlet, @@ -701,6 +745,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( RoomComplexityServlet, FederationSpaceSummaryServlet, FederationRoomHierarchyServlet, + FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, ) diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 53f99031b1..a87896e538 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -40,6 +40,8 @@ from typing import TYPE_CHECKING, Optional, Tuple from signedjson.sign import sign_json +from twisted.internet.defer import Deferred + from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, get_domain_from_id @@ -166,7 +168,7 @@ class GroupAttestionRenewer: return {} - def _start_renew_attestations(self) -> None: + def _start_renew_attestations(self) -> "Deferred[None]": return run_as_background_process("renew_attestations", self._renew_attestations) async def _renew_attestations(self) -> None: diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index be3203ac80..85157a138b 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -234,7 +234,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): @abc.abstractmethod def write_invite( - self, room_id: str, event: EventBase, state: StateMap[dict] + self, room_id: str, event: EventBase, state: StateMap[EventBase] ) -> None: """Write an invite for the room, with associated invite state. @@ -248,7 +248,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): @abc.abstractmethod def write_knock( - self, room_id: str, event: EventBase, state: StateMap[dict] + self, room_id: str, event: EventBase, state: StateMap[EventBase] ) -> None: """Write a knock for the room, with associated knock state. 
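For reference, a hypothetical exchange against the FederationTimestampLookupServlet shown earlier (room ID, timestamps and event ID are invented; the 404 body mirrors the NOT_FOUND error raised in federation_server.py):

GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/%21room%3Aexample.org?ts=1638547200000&dir=f

200 OK
{"event_id": "$closest_event:example.org", "origin_server_ts": 1638547260000}

404 Not Found
{"errcode": "M_NOT_FOUND", "error": "Unable to find event from 1638547200000 in direction f"}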
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 60e59d11a0..61607cf2ba 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,6 +18,7 @@ import time import unicodedata import urllib.parse from binascii import crc32 +from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -38,6 +39,7 @@ import attr import bcrypt import pymacaroons import unpaddedbase64 +from pymacaroons.exceptions import MacaroonVerificationFailedException from twisted.web.server import Request @@ -181,8 +183,11 @@ class LoginTokenAttributes: user_id = attr.ib(type=str) - # the SSO Identity Provider that the user authenticated with, to get this token auth_provider_id = attr.ib(type=str) + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id = attr.ib(type=Optional[str]) + """The session ID advertised by the SSO Identity Provider.""" class AuthHandler: @@ -756,53 +761,109 @@ class AuthHandler: async def refresh_token( self, refresh_token: str, - valid_until_ms: Optional[int], - ) -> Tuple[str, str]: + access_token_valid_until_ms: Optional[int], + refresh_token_valid_until_ms: Optional[int], + ) -> Tuple[str, str, Optional[int]]: """ Consumes a refresh token and generates both a new access token and a new refresh token from it. The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token. + The lifetime of both the access token and refresh token will be capped so that they + do not exceed the session's ultimate expiry time, if applicable. + Args: refresh_token: The token to consume. - valid_until_ms: The expiration timestamp of the new access token. - + access_token_valid_until_ms: The expiration timestamp of the new access token. + None if the access token does not expire. + refresh_token_valid_until_ms: The expiration timestamp of the new refresh token. + None if the refresh token does not expire. Returns: - A tuple containing the new access token and refresh token + A tuple containing: + - the new access token + - the new refresh token + - the actual expiry time of the access token, which may be earlier than + `access_token_valid_until_ms`. """ # Verify the token signature first before looking up the token if not self._verify_refresh_token(refresh_token): - raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN) + raise SynapseError( + HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN + ) existing_token = await self.store.lookup_refresh_token(refresh_token) if existing_token is None: - raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN) + raise SynapseError( + HTTPStatus.UNAUTHORIZED, + "refresh token does not exist", + Codes.UNKNOWN_TOKEN, + ) if ( existing_token.has_next_access_token_been_used or existing_token.has_next_refresh_token_been_refreshed ): raise SynapseError( - 403, "refresh token isn't valid anymore", Codes.FORBIDDEN + HTTPStatus.FORBIDDEN, + "refresh token isn't valid anymore", + Codes.FORBIDDEN, ) + now_ms = self._clock.time_msec() + + if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms: + + raise SynapseError( + HTTPStatus.FORBIDDEN, + "The supplied refresh token has expired", + Codes.FORBIDDEN, + ) + + if existing_token.ultimate_session_expiry_ts is not None: + # This session has a bounded lifetime, even across refreshes. 
+ + if access_token_valid_until_ms is not None: + access_token_valid_until_ms = min( + access_token_valid_until_ms, + existing_token.ultimate_session_expiry_ts, + ) + else: + access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts + + if refresh_token_valid_until_ms is not None: + refresh_token_valid_until_ms = min( + refresh_token_valid_until_ms, + existing_token.ultimate_session_expiry_ts, + ) + else: + refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts + if existing_token.ultimate_session_expiry_ts < now_ms: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "The session has expired and can no longer be refreshed", + Codes.FORBIDDEN, + ) + ( new_refresh_token, new_refresh_token_id, - ) = await self.get_refresh_token_for_user_id( - user_id=existing_token.user_id, device_id=existing_token.device_id + ) = await self.create_refresh_token_for_user_id( + user_id=existing_token.user_id, + device_id=existing_token.device_id, + expiry_ts=refresh_token_valid_until_ms, + ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts, ) - access_token = await self.get_access_token_for_user_id( + access_token = await self.create_access_token_for_user_id( user_id=existing_token.user_id, device_id=existing_token.device_id, - valid_until_ms=valid_until_ms, + valid_until_ms=access_token_valid_until_ms, refresh_token_id=new_refresh_token_id, ) await self.store.replace_refresh_token( existing_token.token_id, new_refresh_token_id ) - return access_token, new_refresh_token + return access_token, new_refresh_token, access_token_valid_until_ms def _verify_refresh_token(self, token: str) -> bool: """ @@ -832,10 +893,12 @@ class AuthHandler: return True - async def get_refresh_token_for_user_id( + async def create_refresh_token_for_user_id( self, user_id: str, device_id: str, + expiry_ts: Optional[int], + ultimate_session_expiry_ts: Optional[int], ) -> Tuple[str, int]: """ Creates a new refresh token for the user with the given user ID. @@ -843,6 +906,13 @@ class AuthHandler: Args: user_id: canonical user ID device_id: the device ID to associate with the token. + expiry_ts (milliseconds since the epoch): Time after which the + refresh token cannot be used. + If None, the refresh token never expires until it has been used. + ultimate_session_expiry_ts (milliseconds since the epoch): + Time at which the session will end and can not be extended any + further. + If None, the session can be refreshed indefinitely. Returns: The newly created refresh token and its ID in the database @@ -852,10 +922,12 @@ class AuthHandler: user_id=user_id, token=refresh_token, device_id=device_id, + expiry_ts=expiry_ts, + ultimate_session_expiry_ts=ultimate_session_expiry_ts, ) return refresh_token, refresh_token_id - async def get_access_token_for_user_id( + async def create_access_token_for_user_id( self, user_id: str, device_id: Optional[str], @@ -1582,6 +1654,7 @@ class AuthHandler: client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, new_user: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> None: """Having figured out a mxid for this user, complete the HTTP request @@ -1597,6 +1670,7 @@ class AuthHandler: during successful login. Must be JSON serializable. new_user: True if we should use wording appropriate to a user who has just registered. + auth_provider_session_id: The session ID from the SSO IdP received during login. """ # If the account has been deactivated, do not proceed with the login # flow. 
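The expiry-capping logic above reduces to a simple rule: each requested lifetime is clamped to the session's ultimate expiry when one is set. A minimal standalone sketch (the function name and the millisecond values are illustrative, not Synapse API):

from typing import Optional

def cap_to_session_expiry(
    valid_until_ms: Optional[int], ultimate_session_expiry_ts: Optional[int]
) -> Optional[int]:
    # No overall session bound: the requested lifetime stands.
    if ultimate_session_expiry_ts is None:
        return valid_until_ms
    # A non-expiring token still ends when the session does.
    if valid_until_ms is None:
        return ultimate_session_expiry_ts
    return min(valid_until_ms, ultimate_session_expiry_ts)

now_ms = 1_000_000
assert cap_to_session_expiry(now_ms + 300_000, now_ms + 60_000) == now_ms + 60_000
assert cap_to_session_expiry(now_ms + 30_000, now_ms + 60_000) == now_ms + 30_000
assert cap_to_session_expiry(None, now_ms + 60_000) == now_ms + 60_000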
@@ -1617,6 +1691,7 @@ class AuthHandler: extra_attributes, new_user=new_user, user_profile_data=profile, + auth_provider_session_id=auth_provider_session_id, ) def _complete_sso_login( @@ -1628,6 +1703,7 @@ class AuthHandler: extra_attributes: Optional[JsonDict] = None, new_user: bool = False, user_profile_data: Optional[ProfileInfo] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ The synchronous portion of complete_sso_login. @@ -1649,7 +1725,9 @@ class AuthHandler: # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( - registered_user_id, auth_provider_id=auth_provider_id + registered_user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) # Append the login token to the original redirect URL (i.e. with its query @@ -1754,6 +1832,7 @@ class MacaroonGenerator: self, user_id: str, auth_provider_id: str, + auth_provider_session_id: Optional[str] = None, duration_in_ms: int = (2 * 60 * 1000), ) -> str: macaroon = self._generate_base_macaroon(user_id) @@ -1762,6 +1841,10 @@ class MacaroonGenerator: expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) + if auth_provider_session_id is not None: + macaroon.add_first_party_caveat( + "auth_provider_session_id = %s" % (auth_provider_session_id,) + ) return macaroon.serialize() def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: @@ -1783,15 +1866,28 @@ class MacaroonGenerator: user_id = get_value_from_macaroon(macaroon, "user_id") auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") + auth_provider_session_id: Optional[str] = None + try: + auth_provider_session_id = get_value_from_macaroon( + macaroon, "auth_provider_session_id" + ) + except MacaroonVerificationFailedException: + pass + v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") v.satisfy_exact("type = login") v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) + v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) satisfy_expiry(v, self.hs.get_clock().time_msec) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) + return LoginTokenAttributes( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) @@ -1828,13 +1924,6 @@ def load_single_legacy_password_auth_provider( logger.error("Error while initializing %r: %s", module, e) raise - # The known hooks. 
If a module implements a method who's name appears in this set - # we'll want to register it - password_auth_provider_methods = { - "check_3pid_auth", - "on_logged_out", - } - # All methods that the module provides should be async, but this wasn't enforced # in the old module system, so we wrap them if needed def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: @@ -1919,11 +2008,14 @@ def load_single_legacy_password_auth_provider( return run - # populate hooks with the implemented methods, wrapped with async_wrapper - hooks = { - hook: async_wrapper(getattr(provider, hook, None)) - for hook in password_auth_provider_methods - } + # If the module has these methods implemented, then we pull them out + # and register them as hooks. + check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper( + getattr(provider, "check_3pid_auth", None) + ) + on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper( + getattr(provider, "on_logged_out", None) + ) supported_login_types = {} # call get_supported_login_types and add that to the dict @@ -1950,7 +2042,11 @@ def load_single_legacy_password_auth_provider( # need to use a tuple here for ("password",) not a list since lists aren't hashable auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password - api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers) + api.register_password_auth_provider_callbacks( + check_3pid_auth=check_3pid_auth_hook, + on_logged_out=on_logged_out_hook, + auth_checkers=auth_checkers, + ) CHECK_3PID_AUTH_CALLBACK = Callable[ diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 68b446eb66..82ee11e921 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: str, device_id: Optional[str], initial_device_display_name: Optional[str] = None, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> str: """ If the given device has not been registered, register it with the @@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: @user:id device_id: device id supplied by client initial_device_display_name: device display name from client + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from the SSO IdP. Returns: device id (generated if none was supplied) """ @@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [device_id]) @@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=new_device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [new_device_id]) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8ca5f60b1c..7ee5c47fd9 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -204,6 +204,10 @@ class DirectoryHandler: ) room_id = await self._delete_association(room_alias) + if room_id is None: + # It's possible someone else deleted the association after the + # checks above, but before we did the deletion. 
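# (Illustrative race: two requests to delete the same alias both pass the
# permission checks; the loser's _delete_association call returns None, so
# we raise NotFoundError here instead of carrying on with no room_id.)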
+ raise NotFoundError("Unknown room alias") try: await self._update_canonical_alias(requester, user_id, room_id, room_alias) @@ -225,7 +229,7 @@ class DirectoryHandler: ) await self._delete_association(room_alias) - async def _delete_association(self, room_alias: RoomAlias) -> str: + async def _delete_association(self, room_alias: RoomAlias) -> Optional[str]: if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 60c11e3d21..b2554bda04 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -65,8 +65,12 @@ class E2eKeysHandler: else: # Only register this edu handler on master as it requires writing # device updates to the db - # - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "m.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) + # also handle the unstable version + # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( "org.matrix.signing_key_update", self._edu_updater.incoming_signing_key_update, diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1f64534a8a..32b0254c5f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -122,8 +122,7 @@ class EventStreamHandler: events, time_now, as_client_event=as_client_event, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. + # Don't bundle aggregations as this is a deprecated API. bundle_aggregations=False, ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3f5d1d701f..b8990da954 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -68,6 +68,37 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: + """Get joined domains from state + + Args: + state: State map from type/state key to event. + + Returns: + Returns a list of servers with the lowest depth of their joins. + Sorted by lowest depth first. + """ + joined_users = [ + (state_key, int(event.depth)) + for (e_type, state_key), event in state.items() + if e_type == EventTypes.Member and event.membership == Membership.JOIN + ] + + joined_domains: Dict[str, int] = {} + for u, d in joined_users: + try: + dom = get_domain_from_id(u) + old_d = joined_domains.get(dom) + if old_d: + joined_domains[dom] = min(d, old_d) + else: + joined_domains[dom] = d + except Exception: + pass + + return sorted(joined_domains.items(), key=lambda d: d[1]) + + class FederationHandler: """Handles general incoming federation requests @@ -274,36 +305,6 @@ class FederationHandler: curr_state = await self.state_handler.get_current_state(room_id) - def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: - """Get joined domains from state - - Args: - state: State map from type/state key to event. - - Returns: - Returns a list of servers with the lowest depth of their joins. - Sorted by lowest depth first. 
- """ - joined_users = [ - (state_key, int(event.depth)) - for (e_type, state_key), event in state.items() - if e_type == EventTypes.Member and event.membership == Membership.JOIN - ] - - joined_domains: Dict[str, int] = {} - for u, d in joined_users: - try: - dom = get_domain_from_id(u) - old_d = joined_domains.get(dom) - if old_d: - joined_domains[dom] = min(d, old_d) - else: - joined_domains[dom] = d - except Exception: - pass - - return sorted(joined_domains.items(), key=lambda d: d[1]) - curr_domains = get_domains_from_state(curr_state) likely_domains = [ diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index afc1de9894..2a0fb9dbdf 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -998,8 +998,6 @@ class FederationEventHandler: origin, event, context, - state=state, - backfilled=backfilled, ) except AuthError as e: # FIXME richvdh 2021/10/07 I don't think this is reachable. Let's log it @@ -1356,8 +1354,6 @@ class FederationEventHandler: origin: str, event: EventBase, context: EventContext, - state: Optional[Iterable[EventBase]] = None, - backfilled: bool = False, ) -> EventContext: """ Checks whether an event should be rejected (for failing auth checks). @@ -1368,12 +1364,6 @@ class FederationEventHandler: context: The event context. - state: - The state events used to check the event for soft-fail. If this is - not provided the current state events will be used. - - backfilled: True if the event was backfilled. - Returns: The updated context object. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 3dbe611f95..c83eaea359 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -464,15 +464,6 @@ class IdentityHandler: if next_link: params["next_link"] = next_link - if self.hs.config.email.using_identity_server_from_trusted_list: - # Warn that a deprecated config option is in use - logger.warning( - 'The config option "trust_identity_server_for_password_resets" ' - 'has been replaced by "account_threepid_delegate". ' - "Please consult the sample config at docs/sample_config.yaml for " - "details and update your config file." - ) - try: data = await self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/email/requestToken", @@ -517,15 +508,6 @@ class IdentityHandler: if next_link: params["next_link"] = next_link - if self.hs.config.email.using_identity_server_from_trusted_list: - # Warn that a deprecated config option is in use - logger.warning( - 'The config option "trust_identity_server_for_password_resets" ' - 'has been replaced by "account_threepid_delegate". ' - "Please consult the sample config at docs/sample_config.yaml for " - "details and update your config file." - ) - try: data = await self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d4e4556155..9cd21e7f2b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -165,7 +165,11 @@ class InitialSyncHandler: invite_event = await self.store.get_event(event.event_id) d["invite"] = await self._event_serializer.serialize_event( - invite_event, time_now, as_client_event + invite_event, + time_now, + # Don't bundle aggregations as this is a deprecated API. 
+ bundle_aggregations=False, + as_client_event=as_client_event, ) rooms_ret.append(d) @@ -216,7 +220,11 @@ class InitialSyncHandler: d["messages"] = { "chunk": ( await self._event_serializer.serialize_events( - messages, time_now=time_now, as_client_event=as_client_event + messages, + time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, + as_client_event=as_client_event, ) ), "start": await start_token.to_string(self.store), @@ -226,6 +234,8 @@ class InitialSyncHandler: d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, as_client_event=as_client_event, ) @@ -366,14 +376,18 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( + # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - room_state.values(), time_now + room_state.values(), time_now, bundle_aggregations=False ) ), "presence": [], @@ -392,8 +406,9 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() + # Don't bundle aggregations as this is a deprecated API. state = await self._event_serializer.serialize_events( - current_state.values(), time_now + current_state.values(), time_now, bundle_aggregations=False ) now_token = self.hs.get_event_sources().get_current_token() @@ -467,7 +482,10 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 916ba0662d..d370b2c302 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -247,13 +247,7 @@ class MessageHandler: room_state = room_state_events[membership_event_id] now = self.clock.time_msec() - events = await self._event_serializer.serialize_events( - room_state.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. - bundle_aggregations=False, - ) + events = await self._event_serializer.serialize_events(room_state.values(), now) return events async def get_joined_members(self, requester: Requester, room_id: str) -> dict: @@ -1011,13 +1005,52 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) + await self._validate_event_relation(event) + logger.debug("Created event %s", event.event_id) + + return event, context + + async def _validate_event_relation(self, event: EventBase) -> None: + """ + Ensure the relation data on a new event is not bogus. + + Args: + event: The event being created. + + Raises: + SynapseError if the event is invalid. 
+ """ + + relation = event.content.get("m.relates_to") + if not relation: + return + + relation_type = relation.get("rel_type") + if not relation_type: + return + + # Ensure the parent is real. + relates_to = relation.get("event_id") + if not relates_to: + return + + parent_event = await self.store.get_event(relates_to, allow_none=True) + if parent_event: + # And in the same room. + if parent_event.room_id != event.room_id: + raise SynapseError(400, "Relations must be in the same room") + + else: + # There must be some reason that the client knows the event exists, + # see if there are existing relations. If so, assume everything is fine. + if not await self.store.event_is_target_of_relation(relates_to): + # Otherwise, the client can't know about the parent event! + raise SynapseError(400, "Can't send relation to unknown event") # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. stops users from liking an # event multiple times). - relation = event.content.get("m.relates_to", {}) - if relation.get("rel_type") == RelationTypes.ANNOTATION: - relates_to = relation["event_id"] + if relation_type == RelationTypes.ANNOTATION: aggregation_key = relation["key"] already_exists = await self.store.has_user_annotated_event( @@ -1026,9 +1059,12 @@ class EventCreationHandler: if already_exists: raise SynapseError(400, "Can't send same reaction twice") - logger.debug("Created event %s", event.event_id) - - return event, context + # Don't attempt to start a thread if the parent event is a relation. + elif relation_type == RelationTypes.THREAD: + if await self.store.event_includes_relation(relates_to): + raise SynapseError( + 400, "Cannot start threads from an event with a relation" + ) @measure_func("handle_new_client_event") async def handle_new_client_event( diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 3665d91513..deb3539751 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -23,7 +23,7 @@ from authlib.common.security import generate_token from authlib.jose import JsonWebToken, jwt from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri -from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo +from authlib.oidc.core import CodeIDToken, UserInfo from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template from pymacaroons.exceptions import ( @@ -117,7 +117,8 @@ class OidcHandler: for idp_id, p in self._providers.items(): try: await p.load_metadata() - await p.load_jwks() + if not p._uses_userinfo: + await p.load_jwks() except Exception as e: raise Exception( "Error while initialising OIDC provider %r" % (idp_id,) @@ -498,10 +499,6 @@ class OidcProvider: return await self._jwks.get() async def _load_jwks(self) -> JWKS: - if self._uses_userinfo: - # We're not using jwt signing, return an empty jwk set - return {"keys": []} - metadata = await self.load_metadata() # Load the JWKS using the `jwks_uri` metadata. @@ -663,7 +660,7 @@ class OidcProvider: return UserInfo(resp) - async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: """Return an instance of UserInfo from token's ``id_token``. Args: @@ -673,7 +670,7 @@ class OidcProvider: request. This value should match the one inside the token. Returns: - An object representing the user. + The decoded claims in the ID token. 
""" metadata = await self.load_metadata() claims_params = { @@ -684,9 +681,6 @@ class OidcProvider: # If we got an `access_token`, there should be an `at_hash` claim # in the `id_token` that we can check against. claims_params["access_token"] = token["access_token"] - claims_cls = CodeIDToken - else: - claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) jwt = JsonWebToken(alg_values) @@ -703,7 +697,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -713,7 +707,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -721,7 +715,8 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) claims.validate(leeway=120) # allows 2 min of clock skew - return UserInfo(claims) + + return claims async def handle_redirect_request( self, @@ -837,8 +832,22 @@ class OidcProvider: logger.debug("Successfully obtained OAuth2 token data: %r", token) - # Now that we have a token, get the userinfo, either by decoding the - # `id_token` or by fetching the `userinfo_endpoint`. + # If there is an id_token, it should be validated, regardless of the + # userinfo endpoint is used or not. + if token.get("id_token") is not None: + try: + id_token = await self._parse_id_token(token, nonce=session_data.nonce) + sid = id_token.get("sid") + except Exception as e: + logger.exception("Invalid id_token") + self._sso_handler.render_error(request, "invalid_token", str(e)) + return + else: + id_token = None + sid = None + + # Now that we have a token, get the userinfo either from the `id_token` + # claims or by fetching the `userinfo_endpoint`. 
if self._uses_userinfo: try: userinfo = await self._fetch_userinfo(token) @@ -846,13 +855,14 @@ class OidcProvider: logger.exception("Could not fetch userinfo") self._sso_handler.render_error(request, "fetch_error", str(e)) return + elif id_token is not None: + userinfo = UserInfo(id_token) else: - try: - userinfo = await self._parse_id_token(token, nonce=session_data.nonce) - except Exception as e: - logger.exception("Invalid id_token") - self._sso_handler.render_error(request, "invalid_token", str(e)) - return + logger.error("Missing id_token in token response") + self._sso_handler.render_error( + request, "invalid_token", "Missing id_token in token response" + ) + return # first check if we're doing a UIA if session_data.ui_auth_session_id: @@ -884,7 +894,7 @@ class OidcProvider: # Call the mapper to register/login the user try: await self._complete_oidc_login( - userinfo, token, request, session_data.client_redirect_url + userinfo, token, request, session_data.client_redirect_url, sid ) except MappingException as e: logger.exception("Could not map user") @@ -896,6 +906,7 @@ class OidcProvider: token: Token, request: SynapseRequest, client_redirect_url: str, + sid: Optional[str], ) -> None: """Given a UserInfo response, complete the login flow @@ -1008,6 +1019,7 @@ class OidcProvider: oidc_response_to_user_attributes, grandfather_existing_users, extra_attributes, + auth_provider_session_id=sid, ) def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index aa26911aed..4f42438053 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Optional, Set +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set import attr @@ -22,7 +22,7 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.logging.context import run_in_background +from synapse.handlers.room import ShutdownRoomResponse from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig @@ -56,11 +56,62 @@ class PurgeStatus: STATUS_FAILED: "failed", } + # Save the error message if an error occurs + error: str = "" + # Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}. status: int = STATUS_ACTIVE def asdict(self) -> JsonDict: - return {"status": PurgeStatus.STATUS_TEXT[self.status]} + ret = {"status": PurgeStatus.STATUS_TEXT[self.status]} + if self.error: + ret["error"] = self.error + return ret + + +@attr.s(slots=True, auto_attribs=True) +class DeleteStatus: + """Object tracking the status of a delete room request + + This class contains information on the progress of a delete room request, for + return by get_delete_status. + """ + + STATUS_PURGING = 0 + STATUS_COMPLETE = 1 + STATUS_FAILED = 2 + STATUS_SHUTTING_DOWN = 3 + + STATUS_TEXT = { + STATUS_PURGING: "purging", + STATUS_COMPLETE: "complete", + STATUS_FAILED: "failed", + STATUS_SHUTTING_DOWN: "shutting_down", + } + + # Tracks whether this request has completed. + # One of STATUS_{PURGING,COMPLETE,FAILED,SHUTTING_DOWN}. 
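# Typical lifecycle, as driven by _shutdown_and_purge_room further down
# (illustrative): STATUS_SHUTTING_DOWN -> STATUS_PURGING -> STATUS_COMPLETE,
# with any state moving to STATUS_FAILED on error.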
+ status: int = STATUS_PURGING + + # Save the error message if an error occurs + error: str = "" + + # Saves the result of the action to give it back to the REST API + shutdown_room: ShutdownRoomResponse = { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } + + def asdict(self) -> JsonDict: + ret = { + "status": DeleteStatus.STATUS_TEXT[self.status], + "shutdown_room": self.shutdown_room, + } + if self.error: + ret["error"] = self.error + return ret class PaginationHandler: @@ -70,6 +121,9 @@ class PaginationHandler: paginating during a purge. """ + # when to remove a completed deletion/purge from the results map + CLEAR_PURGE_AFTER_MS = 1000 * 3600 * 24 # 24 hours + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() @@ -78,11 +132,18 @@ class PaginationHandler: self.state_store = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname + self._room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_lock = ReadWriteLock() + # IDs of rooms in which there is currently an active purge *or delete* operation. self._purges_in_progress_by_room: Set[str] = set() # map from purge id to PurgeStatus self._purges_by_id: Dict[str, PurgeStatus] = {} + # map from purge id to DeleteStatus + self._delete_by_id: Dict[str, DeleteStatus] = {} + # map from room id to delete ids + # Dict[`room_id`, List[`delete_id`]] + self._delete_by_room: Dict[str, List[str]] = {} self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = ( @@ -265,8 +326,13 @@ class PaginationHandler: logger.info("[purge] starting purge_id %s", purge_id) self._purges_by_id[purge_id] = PurgeStatus() - run_in_background( - self._purge_history, purge_id, room_id, token, delete_local_events + run_as_background_process( + "purge_history", + self._purge_history, + purge_id, + room_id, + token, + delete_local_events, ) return purge_id @@ -276,7 +342,7 @@ class PaginationHandler: """Carry out a history purge on a room. Args: - purge_id: The id for this purge + purge_id: The ID for this purge. 
room_id: The room to purge from token: topological token to delete events before delete_local_events: True to delete local events as well as remote ones @@ -295,6 +361,7 @@ class PaginationHandler: "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED + self._purges_by_id[purge_id].error = f.getErrorMessage() finally: self._purges_in_progress_by_room.discard(room_id) @@ -302,7 +369,9 @@ def clear_purge() -> None: del self._purges_by_id[purge_id] - self.hs.get_reactor().callLater(24 * 3600, clear_purge) + self.hs.get_reactor().callLater( + PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_purge + ) def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]: """Get the current status of an active purge @@ -312,17 +381,31 @@ """ return self._purges_by_id.get(purge_id) + def get_delete_status(self, delete_id: str) -> Optional[DeleteStatus]: + """Get the current status of an active delete + + Args: + delete_id: delete_id returned by start_shutdown_and_purge_room + """ + return self._delete_by_id.get(delete_id) + + def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]: + """Get all active delete IDs for a room + + Args: + room_id: the room being deleted + """ + return self._delete_by_room.get(room_id) + async def purge_room(self, room_id: str, force: bool = False) -> None: """Purge the given room from the database. + This function is part of the delete room v1 API. Args: room_id: room to be purged force: set true to skip checking for joined users. """ with await self.pagination_lock.write(room_id): - # check we know about the room - await self.store.get_room_version_id(room_id) - # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) @@ -472,3 +555,192 @@ class PaginationHandler: ) return chunk + + async def _shutdown_and_purge_room( + self, + delete_id: str, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + purge: bool = True, + force_purge: bool = False, + ) -> None: + """ + Shuts down and purges a room. + + See `RoomShutdownHandler.shutdown_room` for details of the creation of the new room. + + Args: + delete_id: The ID for this delete. + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action. Will be recorded as putting the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification`. + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + purge: + If set to `true`, purge the given room from the database.
+ force_purge: + If set to `true`, the room will be purged from the database + even if some users could not be removed from the room. + + Saves a `RoomShutdownHandler.ShutdownRoomResponse` in `DeleteStatus`. + """ + + self._purges_in_progress_by_room.add(room_id) + try: + with await self.pagination_lock.write(room_id): + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN + self._delete_by_id[ + delete_id + ].shutdown_room = await self._room_shutdown_handler.shutdown_room( + room_id=room_id, + requester_user_id=requester_user_id, + new_room_user_id=new_room_user_id, + new_room_name=new_room_name, + message=message, + block=block, + ) + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_PURGING + + if purge: + logger.info("starting purge room_id %s", room_id) + + # first check that we have no users in this room + if not force_purge: + joined = await self.store.is_host_joined( + room_id, self._server_name + ) + if joined: + raise SynapseError( + 400, "Users are still joined to this room" + ) + + await self.storage.purge_events.purge_room(room_id) + + logger.info("complete") + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE + except Exception: + f = Failure() + logger.error( + "failed", + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + self._delete_by_id[delete_id].status = DeleteStatus.STATUS_FAILED + self._delete_by_id[delete_id].error = f.getErrorMessage() + finally: + self._purges_in_progress_by_room.discard(room_id) + + # remove the delete from the list 24 hours after it completes + def clear_delete() -> None: + del self._delete_by_id[delete_id] + self._delete_by_room[room_id].remove(delete_id) + if not self._delete_by_room[room_id]: + del self._delete_by_room[room_id] + + self.hs.get_reactor().callLater( + PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_delete + ) + + def start_shutdown_and_purge_room( + self, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + purge: bool = True, + force_purge: bool = False, + ) -> str: + """Start a shutdown and purge of a room. + + Args: + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action and put the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification`. + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + purge: + If set to `true`, purge the given room from the database. + force_purge: + If set to `true`, the room will be purged from the database + even if some users could not be removed from the room. + + Returns: + unique ID for this delete transaction.
+ """ + if room_id in self._purges_in_progress_by_room: + raise SynapseError( + 400, "History purge already in progress for %s" % (room_id,) + ) + + # This check duplicates the one in `RoomShutdownHandler.shutdown_room`, + # but doing it here means the requester gets a direct response / error + # from the HTTP request and does not have to poll the purge status + if new_room_user_id is not None: + if not self.hs.is_mine_id(new_room_user_id): + raise SynapseError( + 400, "User must be our own: %s" % (new_room_user_id,) + ) + + delete_id = random_string(16) + + # we log the delete_id here so that it can be tied back to the + # request id in the log lines. + logger.info( + "starting shutdown room_id %s with delete_id %s", + room_id, + delete_id, + ) + + self._delete_by_id[delete_id] = DeleteStatus() + self._delete_by_room.setdefault(room_id, []).append(delete_id) + run_as_background_process( + "shutdown_and_purge_room", + self._shutdown_and_purge_room, + delete_id, + room_id, + requester_user_id, + new_room_user_id, + new_room_name, + message, + block, + purge, + force_purge, + ) + return delete_id diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3df872c578..454d06c973 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -421,7 +421,7 @@ class WorkerPresenceHandler(BasePresenceHandler): self._on_shutdown, ) - def _on_shutdown(self) -> None: + async def _on_shutdown(self) -> None: if self._presence_enabled: self.hs.get_tcp_replication().send_command( ClearUserSyncsCommand(self.instance_id) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4911a11535..5cb1ff749d 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -178,7 +178,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): for event_id in content.keys(): event_content = content.get(event_id, {}) - m_read = event_content.get("m.read", {}) + m_read = event_content.get(ReceiptTypes.READ, {}) # If m_read is missing copy over the original event_content as there is nothing to process here if not m_read: @@ -206,7 +206,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): # Set new users unless empty if len(new_users.keys()) > 0: - new_event["content"][event_id] = {"m.read": new_users} + new_event["content"][event_id] = {ReceiptTypes.READ: new_users} # Append new_event to visible_events unless empty if len(new_event["content"].keys()) > 0: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a0e6a01775..f08a516a75 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -1,4 +1,5 @@ # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
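The new delete flow above is deliberately asynchronous: `start_shutdown_and_purge_room` returns a `delete_id` straight away and the heavy lifting runs in a background process, so callers are expected to poll `get_delete_status` until the `DeleteStatus` reaches `complete` or `failed`. A minimal polling client might look like the sketch below; the admin-API paths and the bearer token are assumptions for illustration and do not appear in this diff.

    import time
    import requests  # any HTTP client would do

    BASE_URL = "https://synapse.example.com"  # hypothetical homeserver
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder token

    def delete_room_and_wait(room_id: str) -> dict:
        # Kick off the shutdown-and-purge; the handler returns a random
        # 16-character delete_id for this transaction.
        resp = requests.delete(
            f"{BASE_URL}/_synapse/admin/v2/rooms/{room_id}",
            headers=HEADERS,
            json={"purge": True},
        )
        resp.raise_for_status()
        delete_id = resp.json()["delete_id"]

        # Poll the DeleteStatus until it leaves "shutting_down"/"purging".
        while True:
            status = requests.get(
                f"{BASE_URL}/_synapse/admin/v2/rooms/delete_status/{delete_id}",
                headers=HEADERS,
            ).json()
            if status["status"] in ("complete", "failed"):
                return status
            time.sleep(2)

Note how the assumed response shape follows `DeleteStatus.asdict()` above: a `status` string, a `shutdown_room` object, and an `error` message only when the background process failed.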
@@ -116,7 +117,13 @@ class RegistrationHandler: self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.registration.session_lifetime - self.access_token_lifetime = hs.config.registration.access_token_lifetime + self.nonrefreshable_access_token_lifetime = ( + hs.config.registration.nonrefreshable_access_token_lifetime + ) + self.refreshable_access_token_lifetime = ( + hs.config.registration.refreshable_access_token_lifetime + ) + self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime init_counters_for_auth_provider("") @@ -739,6 +746,7 @@ class RegistrationHandler: is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> Tuple[str, str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. @@ -749,9 +757,9 @@ class RegistrationHandler: device_id: The device ID to check, or None to generate a new one. initial_display_name: An optional display name for the device. is_guest: Whether this is a guest account - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: Whether it should also issue a refresh token + auth_provider_session_id: The session ID received during login from the SSO IdP. Returns: Tuple of device ID, access token, access token expiration time and refresh token """ @@ -762,6 +770,8 @@ class RegistrationHandler: is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) login_counter.labels( @@ -784,6 +794,8 @@ class RegistrationHandler: is_guest: bool = False, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> LoginDict: """Helper for register_device @@ -791,38 +803,86 @@ class RegistrationHandler: class and RegisterDeviceReplicationServlet. """ assert not self.hs.config.worker.worker_app - valid_until_ms = None + now_ms = self.clock.time_msec() + access_token_expiry = None if self.session_lifetime is not None: if is_guest: raise Exception( "session_lifetime is not currently implemented for guest access" ) - valid_until_ms = self.clock.time_msec() + self.session_lifetime + access_token_expiry = now_ms + self.session_lifetime + + if self.nonrefreshable_access_token_lifetime is not None: + if access_token_expiry is not None: + # Don't allow the non-refreshable access token to outlive the + # session. 
+ access_token_expiry = min( + now_ms + self.nonrefreshable_access_token_lifetime, + access_token_expiry, + ) + else: + access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime refresh_token = None refresh_token_id = None registered_device_id = await self.device_handler.check_device_registered( - user_id, device_id, initial_display_name + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if is_guest: - assert valid_until_ms is None + assert access_token_expiry is None access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: if should_issue_refresh_token: + # A refreshable access token lifetime must be configured + # since we're told to issue a refresh token (the caller checks + # that this value is set before setting this flag). + assert self.refreshable_access_token_lifetime is not None + + # Set the expiry time of the refreshable access token + access_token_expiry = now_ms + self.refreshable_access_token_lifetime + + # Set the refresh token expiry time (if configured) + refresh_token_expiry = None + if self.refresh_token_lifetime is not None: + refresh_token_expiry = now_ms + self.refresh_token_lifetime + + # Set an ultimate session expiry time (if configured) + ultimate_session_expiry_ts = None + if self.session_lifetime is not None: + ultimate_session_expiry_ts = now_ms + self.session_lifetime + + # Also ensure that the issued tokens don't outlive the + # session. + # (It would be weird to configure a homeserver with a shorter + # session lifetime than token lifetime, but may as well handle + # it.) + access_token_expiry = min( + access_token_expiry, ultimate_session_expiry_ts + ) + if refresh_token_expiry is not None: + refresh_token_expiry = min( + refresh_token_expiry, ultimate_session_expiry_ts + ) + ( refresh_token, refresh_token_id, - ) = await self._auth_handler.get_refresh_token_for_user_id( + ) = await self._auth_handler.create_refresh_token_for_user_id( user_id, device_id=registered_device_id, + expiry_ts=refresh_token_expiry, + ultimate_session_expiry_ts=ultimate_session_expiry_ts, ) - valid_until_ms = self.clock.time_msec() + self.access_token_lifetime - access_token = await self._auth_handler.get_access_token_for_user_id( + access_token = await self._auth_handler.create_access_token_for_user_id( user_id, device_id=registered_device_id, - valid_until_ms=valid_until_ms, + valid_until_ms=access_token_expiry, is_appservice_ghost=is_appservice_ghost, refresh_token_id=refresh_token_id, ) @@ -830,7 +890,7 @@ class RegistrationHandler: return { "device_id": registered_device_id, "access_token": access_token, - "valid_until_ms": valid_until_ms, + "valid_until_ms": access_token_expiry, "refresh_token": refresh_token, } diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 11af30eee7..ead2198e14 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -46,6 +46,7 @@ from synapse.api.constants import ( from synapse.api.errors import ( AuthError, Codes, + HttpResponseException, LimitExceededError, NotFoundError, StoreError, @@ -56,6 +57,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents +from synapse.federation.federation_client import InvalidResponseError +from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base 
import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.streams import EventSource @@ -775,8 +778,11 @@ class RoomCreationHandler: raise SynapseError(403, "Room visibility value not allowed.") if is_public: + room_aliases = [] + if room_alias: + room_aliases.append(room_alias.to_string()) if not self.config.roomdirectory.is_publishing_room_allowed( - user_id, room_id, room_alias + user_id, room_id, room_aliases ): # Let's just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages @@ -1217,6 +1223,147 @@ class RoomContextHandler: return results +class TimestampLookupHandler: + def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.state_handler = hs.get_state_handler() + self.federation_client = hs.get_federation_client() + + async def get_event_for_timestamp( + self, + requester: Requester, + room_id: str, + timestamp: int, + direction: str, + ) -> Tuple[str, int]: + """Find the closest event to the given timestamp in the given direction. + If we can't find an event locally or the event we have locally is next to a gap, + it will ask other federated homeservers for an event. + + Args: + requester: The user making the request according to the access token + room_id: Room to fetch the event from + timestamp: The point in time (inclusive) we should navigate from in + the given direction to find the closest event. + direction: ["f"|"b"] to indicate whether we should navigate forward + or backward from the given timestamp to find the closest event. + + Returns: + A tuple containing the `event_id` closest to the given timestamp in + the given direction and the `origin_server_ts`. + + Raises: + SynapseError if unable to find any event locally in the given direction + """ + + local_event_id = await self.store.get_event_id_for_timestamp( + room_id, timestamp, direction + ) + logger.debug( + "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s", + local_event_id, + timestamp, + ) + + # Check for gaps in the history where events could be hiding in between + # the timestamp given and the event we were able to find locally + is_event_next_to_backward_gap = False + is_event_next_to_forward_gap = False + if local_event_id: + local_event = await self.store.get_event( + local_event_id, allow_none=False, allow_rejected=False + ) + + if direction == "f": + # We only need to check for a backward gap if we're looking forwards + # to ensure there is nothing in between. 
+ is_event_next_to_backward_gap = ( + await self.store.is_event_next_to_backward_gap(local_event) + ) + elif direction == "b": + # We only need to check for a forward gap if we're looking backwards + # to ensure there is nothing in between + is_event_next_to_forward_gap = ( + await self.store.is_event_next_to_forward_gap(local_event) + ) + + # If we found a gap, we should probably ask another homeserver first + # about more history in between + if ( + not local_event_id + or is_event_next_to_backward_gap + or is_event_next_to_forward_gap + ): + logger.debug( + "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first", + local_event_id, + timestamp, + ) + + # Find other homeservers from the given state in the room + curr_state = await self.state_handler.get_current_state(room_id) + curr_domains = get_domains_from_state(curr_state) + likely_domains = [ + domain for domain, depth in curr_domains if domain != self.server_name + ] + + # Loop through each homeserver candidate until we get a successful response + for domain in likely_domains: + try: + remote_response = await self.federation_client.timestamp_to_event( + domain, room_id, timestamp, direction + ) + logger.debug( + "get_event_for_timestamp: response from domain(%s)=%s", + domain, + remote_response, + ) + + # TODO: Do we want to persist this as an extremity? + # TODO: I think ideally, we would try to backfill from + # this event and run this whole + # `get_event_for_timestamp` function again to make sure + # they didn't give us an event from their gappy history. + remote_event_id = remote_response.event_id + origin_server_ts = remote_response.origin_server_ts + + # Only return the remote event if it's closer than the local event + if not local_event or ( + abs(origin_server_ts - timestamp) + < abs(local_event.origin_server_ts - timestamp) + ): + return remote_event_id, origin_server_ts + except (HttpResponseException, InvalidResponseError) as ex: + # Let's not put a high priority on some other homeserver + # failing to respond or giving a random response + logger.debug( + "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + domain, + type(ex).__name__, + ex, + ex.args, + ) + except Exception as ex: + # But we do want to see some exceptions in our code + logger.warning( + "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + domain, + type(ex).__name__, + ex, + ex.args, + ) + + if not local_event_id: + raise SynapseError( + 404, + "Unable to find event from %s in direction %s" % (timestamp, direction), + errcode=Codes.NOT_FOUND, + ) + + return local_event_id, local_event.origin_server_ts + + class RoomEventSource(EventSource[RoomStreamToken, EventBase]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -1279,6 +1426,17 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): class ShutdownRoomResponse(TypedDict): + """ + Attributes: + kicked_users: An array of users (`user_id`) that were kicked. + failed_to_kick_users: + An array of users (`user_id`) that were not kicked. + local_aliases: + An array of strings representing the local aliases that were + migrated from the old room to the new. + new_room_id: A string representing the room ID of the new room.
+ """ + kicked_users: List[str] failed_to_kick_users: List[str] local_aliases: List[str] @@ -1286,7 +1444,6 @@ class ShutdownRoomResponse(TypedDict): class RoomShutdownHandler: - DEFAULT_MESSAGE = ( "Sharing illegal content on this server is not permitted and rooms in" " violation will be blocked." @@ -1299,7 +1456,6 @@ class RoomShutdownHandler: self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.state = hs.get_state_handler() self.store = hs.get_datastore() async def shutdown_room( @@ -1379,20 +1535,13 @@ class RoomShutdownHandler: await self.store.block_room(room_id, requester_user_id) if not await self.store.get_room(room_id): - if block: - # We allow you to block an unknown room. - return { - "kicked_users": [], - "failed_to_kick_users": [], - "local_aliases": [], - "new_room_id": None, - } - else: - # But if you don't want to preventatively block another room, - # this function can't do anything useful. - raise NotFoundError( - "Cannot shut down room: unknown room id %s" % (room_id,) - ) + # if we don't know about the room, there is nothing left to do. + return { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 2e31532389..c4d22ad297 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -217,6 +217,7 @@ class RoomBatchHandler: action=membership, content=event_dict["content"], outlier=True, + historical=True, prev_event_ids=prev_event_ids_for_state_chain, # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same @@ -236,6 +237,7 @@ class RoomBatchHandler: ), event_dict, outlier=True, + historical=True, prev_event_ids=prev_event_ids_for_state_chain, # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 343e3c6b7b..805c49ac01 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -268,6 +268,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content: Optional[dict] = None, require_consent: bool = True, outlier: bool = False, + historical: bool = False, ) -> Tuple[str, int]: """ Internal membership update function to get an existing event or create @@ -293,6 +294,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. 
Returns: Tuple of event ID and stream ordering position @@ -337,6 +341,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): auth_event_ids=auth_event_ids, require_consent=require_consent, outlier=outlier, + historical=historical, ) prev_state_ids = await context.get_prev_state_ids() @@ -433,6 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room: bool = False, require_consent: bool = True, outlier: bool = False, + historical: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: @@ -454,6 +460,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. prev_event_ids: The event IDs to use as the prev events auth_event_ids: The event ids to use as the auth_events for the new event. @@ -487,6 +496,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room=new_room, require_consent=require_consent, outlier=outlier, + historical=historical, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, ) @@ -507,6 +517,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room: bool = False, require_consent: bool = True, outlier: bool = False, + historical: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: @@ -530,6 +541,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted + back in time around some existing events. This is used to skip + a few checks and mark the event as backfilled. prev_event_ids: The event IDs to use as the prev events auth_event_ids: The event ids to use as the auth_events for the new event. @@ -657,6 +671,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content=content, require_consent=require_consent, outlier=outlier, + historical=historical, ) latest_event_ids = await self.store.get_prev_events_for_room(room_id) diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index fb26ee7ad7..b2cfe537df 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -36,8 +36,9 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase -from synapse.types import JsonDict +from synapse.types import JsonDict, Requester from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -93,11 +94,14 @@ class RoomSummaryHandler: self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() + self._ratelimiter = Ratelimiter( + store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 + ) # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. 
self._pagination_response_cache: ResponseCache[ - Tuple[str, bool, Optional[int], Optional[int], Optional[str]] + Tuple[str, str, bool, Optional[int], Optional[int], Optional[str]] ] = ResponseCache( hs.get_clock(), "get_room_hierarchy", @@ -249,7 +253,7 @@ class RoomSummaryHandler: async def get_room_hierarchy( self, - requester: str, + requester: Requester, requested_room_id: str, suggested_only: bool = False, max_depth: Optional[int] = None, @@ -276,15 +280,24 @@ class RoomSummaryHandler: Returns: The JSON hierarchy dictionary. """ + await self._ratelimiter.ratelimit(requester) + # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. # # This is due to the pagination process mutating internal state, attempting # to process multiple requests for the same page will result in errors. return await self._pagination_response_cache.wrap( - (requested_room_id, suggested_only, max_depth, limit, from_token), + ( + requester.user.to_string(), + requested_room_id, + suggested_only, + max_depth, + limit, + from_token, + ), self._get_room_hierarchy, - requester, + requester.user.to_string(), requested_room_id, suggested_only, max_depth, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 49fde01cf0..65c27bc64a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -365,6 +365,7 @@ class SsoHandler: sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], extra_login_attributes: Optional[JsonDict] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -415,6 +416,8 @@ class SsoHandler: extra_login_attributes: An optional dictionary of extra attributes to be provided to the client in the login response. + auth_provider_session_id: An optional session ID from the IdP. + Raises: MappingException if there was a problem mapping the response to a user. RedirectException: if the mapping provider needs to redirect the user @@ -490,6 +493,7 @@ class SsoHandler: client_redirect_url, extra_login_attributes, new_user=new_user, + auth_provider_session_id=auth_provider_session_id, ) async def _call_attribute_mapper( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 891435c14d..96f37e9f42 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -28,7 +28,7 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -334,6 +334,19 @@ class SyncHandler: full_state: bool, cache_context: ResponseCacheContext[SyncRequestKey], ) -> SyncResult: + """The start of the machinery that produces a /sync response. + + See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details. + + This method does high-level bookkeeping: + - tracking the kind of sync in the logging context + - deleting any to_device messages whose delivery has been acknowledged. 
+ - deciding if we should dispatch an instant or delayed response + - marking the sync as being lazily loaded, if appropriate + + Computing the body of the response begins in the next method, + `current_sync_for_user`. + """ if since_token is None: sync_type = "initial_sync" elif full_state: @@ -363,7 +376,7 @@ class SyncHandler: sync_config, since_token, full_state=full_state ) else: - + # Otherwise, we wait for something to happen and report it to the user. async def current_sync_callback( before_token: StreamToken, after_token: StreamToken ) -> SyncResult: @@ -402,7 +415,12 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Get the sync for client needed to match what the server has now.""" + """Generates the response body of a sync result, represented as a SyncResult. + + This is a wrapper around `generate_sync_result` which starts an open tracing + span to track the sync. See `generate_sync_result` for the next part of your + indoctrination. + """ with start_active_span("current_sync_for_user"): log_kv({"since_token": since_token}) sync_result = await self.generate_sync_result( @@ -560,7 +578,7 @@ class SyncHandler: # that have happened since `since_key` up to `end_key`, so we # can just use `get_room_events_stream_for_room`. # Otherwise, we want to return the last N events in the room - # in toplogical ordering. + # in topological ordering. if since_key: events, end_key = await self.store.get_room_events_stream_for_room( room_id, @@ -1028,7 +1046,7 @@ class SyncHandler: last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read", + receipt_type=ReceiptTypes.READ, ) notifs = await self.store.get_unread_event_push_actions_by_room_for_user( @@ -1042,7 +1060,18 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Generates a sync result.""" + """Generates the response body of a sync result. + + This is represented by a `SyncResult` struct, which is built from small pieces + using a `SyncResultBuilder`. See also + https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync + the `sync_result_builder` is passed as a mutable ("inout") parameter to various + helper functions. These retrieve and process the data which forms the sync body, + often writing to the `sync_result_builder` to store their output. + + At the end, we transfer data from the `sync_result_builder` to a new `SyncResult` + instance to signify that the sync calculation is complete. + """ # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -1344,14 +1373,22 @@ class SyncHandler: async def _generate_sync_entry_for_account_data( self, sync_result_builder: "SyncResultBuilder" ) -> Dict[str, Dict[str, JsonDict]]: - """Generates the account data portion of the sync response. Populates - `sync_result_builder` with the result. + """Generates the account data portion of the sync response. + + Account data (called "Client Config" in the spec) can be set either globally + or for a specific room. Account data consists of a list of events which + accumulate state, much like a room. + + This function retrieves global and per-room account data. The former is written + to the given `sync_result_builder`. 
The latter is returned directly, to be + later written to the `sync_result_builder` on a room-by-room basis. Args: sync_result_builder Returns: - A dictionary containing the per room account data. + A dictionary whose keys (room ids) map to the per room account data for that + room. """ sync_config = sync_result_builder.sync_config user_id = sync_result_builder.sync_config.user.to_string() @@ -1359,7 +1396,7 @@ class SyncHandler: if since_token and not sync_result_builder.full_state: ( - account_data, + global_account_data, account_data_by_room, ) = await self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key @@ -1370,23 +1407,23 @@ class SyncHandler: ) if push_rules_changed: - account_data["m.push_rules"] = await self.push_rules_for_user( + global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: ( - account_data, + global_account_data, account_data_by_room, ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) - account_data["m.push_rules"] = await self.push_rules_for_user( + global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) account_data_for_user = await sync_config.filter_collection.filter_account_data( [ {"type": account_data_type, "content": content} - for account_data_type, content in account_data.items() + for account_data_type, content in global_account_data.items() ] ) @@ -1460,18 +1497,31 @@ class SyncHandler: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. + In the response that reaches the client, rooms are divided into four categories: + `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of + room ids returned by this function. + Args: sync_result_builder account_data_by_room: Dictionary of per room account data Returns: - Returns a 4-tuple of - `(newly_joined_rooms, newly_joined_or_invited_users, - newly_left_rooms, newly_left_users)` + Returns a 4-tuple describing rooms the user has joined or left, and users who've + joined or left rooms any rooms the user is in. This gets used later in + `_generate_sync_entry_for_device_list`. + + Its entries are: + - newly_joined_rooms + - newly_joined_or_invited_or_knocked_users + - newly_left_rooms + - newly_left_users """ + since_token = sync_result_builder.since_token + + # 1. Start by fetching all ephemeral events in rooms we've joined (if required). user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( - sync_result_builder.since_token is None + since_token is None and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) @@ -1485,9 +1535,8 @@ class SyncHandler: ) sync_result_builder.now_token = now_token - # We check up front if anything has changed, if it hasn't then there is + # 2. We check up front if anything has changed, if it hasn't then there is # no point in going further. - since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: have_changed = await self._have_rooms_changed(sync_result_builder) @@ -1500,20 +1549,8 @@ class SyncHandler: logger.debug("no-oping sync") return set(), set(), set(), set() - ignored_account_data = ( - await self.store.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id=user_id - ) - ) - - # If there is ignored users account data and it matches the proper type, - # then use it. 
- ignored_users: FrozenSet[str] = frozenset() - if ignored_account_data: - ignored_users_data = ignored_account_data.get("ignored_users", {}) - if isinstance(ignored_users_data, dict): - ignored_users = frozenset(ignored_users_data.keys()) - + # 3. Work out which rooms need reporting in the sync response. + ignored_users = await self._get_ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( sync_result_builder, ignored_users @@ -1523,7 +1560,6 @@ class SyncHandler: ) else: room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) - tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1534,6 +1570,8 @@ class SyncHandler: newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms + # 4. We need to apply further processing to `room_entries` (rooms considered + # joined or archived). async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: logger.debug("Generating room entry for %s", room_entry.room_id) await self._generate_room_entry( @@ -1552,31 +1590,13 @@ class SyncHandler: sync_result_builder.invited.extend(invited) sync_result_builder.knocked.extend(knocked) - # Now we want to get any newly joined, invited or knocking users - newly_joined_or_invited_or_knocked_users = set() - newly_left_users = set() - if since_token: - for joined_sync in sync_result_builder.joined: - it = itertools.chain( - joined_sync.timeline.events, joined_sync.state.values() - ) - for event in it: - if event.type == EventTypes.Member: - if ( - event.membership == Membership.JOIN - or event.membership == Membership.INVITE - or event.membership == Membership.KNOCK - ): - newly_joined_or_invited_or_knocked_users.add( - event.state_key - ) - else: - prev_content = event.unsigned.get("prev_content", {}) - prev_membership = prev_content.get("membership", None) - if prev_membership == Membership.JOIN: - newly_left_users.add(event.state_key) - - newly_left_users -= newly_joined_or_invited_or_knocked_users + # 5. Work out which users have joined or left rooms we're in. We use this + # to build the device_list part of the sync response in + # `_generate_sync_entry_for_device_list`. + ( + newly_joined_or_invited_or_knocked_users, + newly_left_users, + ) = sync_result_builder.calculate_user_changes() return ( set(newly_joined_rooms), @@ -1585,11 +1605,36 @@ class SyncHandler: newly_left_users, ) + async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]: + """Retrieve the users ignored by the given user from their global account_data. + + Returns an empty set if + - there is no global account_data entry for ignored_users + - there is such an entry, but it's not a JSON object. + """ + # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead? + ignored_account_data = ( + await self.store.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ) + ) + + # If there is ignored users account data and it matches the proper type, + # then use it. 
+ ignored_users: FrozenSet[str] = frozenset() + if ignored_account_data: + ignored_users_data = ignored_account_data.get("ignored_users", {}) + if isinstance(ignored_users_data, dict): + ignored_users = frozenset(ignored_users_data.keys()) + return ignored_users + async def _have_rooms_changed( self, sync_result_builder: "SyncResultBuilder" ) -> bool: """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. + + Does not modify the `sync_result_builder`. """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1597,12 +1642,13 @@ class SyncHandler: assert since_token - # Get a list of membership change events that have happened. - rooms_changed = await self.store.get_membership_changes_for_user( + # Get a list of membership change events that have happened to the user + # requesting the sync. + membership_changes = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) - if rooms_changed: + if membership_changes: return True stream_id = since_token.room_key.stream @@ -1614,7 +1660,25 @@ class SyncHandler: async def _get_rooms_changed( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: - """Gets the the changes that have happened since the last sync.""" + """Determine the changes in rooms to report to the user. + + Ideally, we want to report all events whose stream ordering `s` lies in the + range `since_token < s <= now_token`, where the two tokens are read from the + sync_result_builder. + + If there are too many events in that range to report, things get complicated. + In this situation we return a truncated list of the most recent events, and + indicate in the response that there is a "gap" of omitted events. Additionally: + + - we include a "state_delta", to describe the changes in state over the gap, + - we include all membership events applying to the user making the request, + even those in the gap. + + See the spec for the rationale: + https://spec.matrix.org/v1.1/client-server-api/#syncing + + The sync_result_builder is not modified by this function. + """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token @@ -1622,21 +1686,36 @@ class SyncHandler: assert since_token - # Get a list of membership change events that have happened. - rooms_changed = await self.store.get_membership_changes_for_user( + # The spec + # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync + # notes that membership events need special consideration: + # + # > When a sync is limited, the server MUST return membership events for events + # > in the gap (between since and the start of the returned timeline), regardless + # > as to whether or not they are redundant. + # + # We fetch such events here, but we only seem to use them for categorising rooms + # as newly joined, newly left, invited or knocked. + # TODO: we've already called this function and ran this query in + # _have_rooms_changed. We could keep the results in memory to avoid a + # second query, at the cost of more complicated source code. 
+ membership_change_events = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} - for event in rooms_changed: + for event in membership_change_events: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - newly_joined_rooms = [] - newly_left_rooms = [] - room_entries = [] - invited = [] - knocked = [] + newly_joined_rooms: List[str] = [] + newly_left_rooms: List[str] = [] + room_entries: List[RoomSyncResultBuilder] = [] + invited: List[InvitedSyncResult] = [] + knocked: List[KnockedSyncResult] = [] for room_id, events in mem_change_events_by_room_id.items(): + # The body of this loop will add this room to at least one of the five lists + # above. Things get messy if you've e.g. joined, left, joined then left the + # room all in the same sync period. logger.debug( "Membership changes in %s: [%s]", room_id, @@ -1691,6 +1770,7 @@ class SyncHandler: if not non_joins: continue + last_non_join = non_joins[-1] # Check if we have left the room. This can either be because we were # joined before *or* that we since joined and then left. @@ -1712,18 +1792,18 @@ class SyncHandler: newly_left_rooms.append(room_id) # Only bother if we're still currently invited - should_invite = non_joins[-1].membership == Membership.INVITE + should_invite = last_non_join.membership == Membership.INVITE if should_invite: - if event.sender not in ignored_users: - invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if last_non_join.sender not in ignored_users: + invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join) if invite_room_sync: invited.append(invite_room_sync) # Only bother if our latest membership in the room is knock (and we haven't # been accepted/rejected in the meantime). - should_knock = non_joins[-1].membership == Membership.KNOCK + should_knock = last_non_join.membership == Membership.KNOCK if should_knock: - knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1]) + knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join) if knock_room_sync: knocked.append(knock_room_sync) @@ -1781,7 +1861,9 @@ class SyncHandler: timeline_limit = sync_config.filter_collection.timeline_limit() - # Get all events for rooms we're currently joined to. + # Get all events since the `from_key` in rooms we're currently joined to. + # If there are too many, we get the most recent events only. This leaves + # a "gap" in the timeline, as described by the spec for /sync. room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, @@ -1842,6 +1924,10 @@ class SyncHandler: ) -> _RoomChanges: """Returns entries for all rooms for the user. + Like `_get_rooms_changed`, but assumes the `since_token` is `None`. + + This function does not modify the sync_result_builder. + Args: sync_result_builder ignored_users: Set of users ignored by user. 
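The `last_non_join` refactor in this hunk makes the categorisation rule explicit: within one sync window, only the most recent non-join membership event decides whether a room is reported as invited or knocked. A self-contained sketch of that rule, using illustrative names rather than Synapse's own types:

    from collections import defaultdict
    from typing import Dict, List, NamedTuple

    class MembershipEvent(NamedTuple):
        room_id: str
        membership: str  # "join" | "invite" | "knock" | "leave" | "ban"

    def categorise(events: List[MembershipEvent]) -> Dict[str, str]:
        # Bucket membership changes per room, assuming stream order,
        # like mem_change_events_by_room_id in the diff.
        by_room: Dict[str, List[MembershipEvent]] = defaultdict(list)
        for ev in events:
            by_room[ev.room_id].append(ev)

        result: Dict[str, str] = {}
        for room_id, evs in by_room.items():
            non_joins = [e for e in evs if e.membership != "join"]
            if not non_joins:
                continue
            last = non_joins[-1]  # mirrors `last_non_join`
            if last.membership == "invite":
                result[room_id] = "invited"
            elif last.membership == "knock":
                result[room_id] = "knocked"
        return result

    # An invite, then a join, then a fresh invite: still reported as invited.
    print(categorise([
        MembershipEvent("!a:example.org", "invite"),
        MembershipEvent("!a:example.org", "join"),
        MembershipEvent("!a:example.org", "invite"),
    ]))

The real code additionally drops invites from ignored users and computes newly-left rooms; the sketch covers only the last-non-join rule.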
@@ -1853,16 +1939,9 @@ class SyncHandler: now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config - membership_list = ( - Membership.INVITE, - Membership.KNOCK, - Membership.JOIN, - Membership.LEAVE, - Membership.BAN, - ) - room_list = await self.store.get_rooms_for_local_user_where_membership_is( - user_id=user_id, membership_list=membership_list + user_id=user_id, + membership_list=Membership.LIST, ) room_entries = [] @@ -2212,8 +2291,7 @@ def _calculate_state( # to only include membership events for the senders in the timeline. # In practice, we can do this by removing them from the p_ids list, # which is the list of relevant state we know we have already sent to the client. - # see https://github.com/matrix-org/synapse/pull/2970 - # /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 + # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 if lazy_load_members: p_ids.difference_update( @@ -2262,6 +2340,39 @@ class SyncResultBuilder: groups: Optional[GroupsSyncResult] = None to_device: List[JsonDict] = attr.Factory(list) + def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]: + """Work out which other users have joined or left rooms we are joined to. + + This data only is only useful for an incremental sync. + + The SyncResultBuilder is not modified by this function. + """ + newly_joined_or_invited_or_knocked_users = set() + newly_left_users = set() + if self.since_token: + for joined_sync in self.joined: + it = itertools.chain( + joined_sync.timeline.events, joined_sync.state.values() + ) + for event in it: + if event.type == EventTypes.Member: + if ( + event.membership == Membership.JOIN + or event.membership == Membership.INVITE + or event.membership == Membership.KNOCK + ): + newly_joined_or_invited_or_knocked_users.add( + event.state_key + ) + else: + prev_content = event.unsigned.get("prev_content", {}) + prev_membership = prev_content.get("membership", None) + if prev_membership == Membership.JOIN: + newly_left_users.add(event.state_key) + + newly_left_users -= newly_joined_or_invited_or_knocked_users + return newly_joined_or_invited_or_knocked_users, newly_left_users + @attr.s(slots=True, auto_attribs=True) class RoomSyncResultBuilder: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 22c6174821..1676ebd057 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -90,7 +90,7 @@ class FollowerTypingHandler: self.wheel_timer = WheelTimer(bucket_size=5000) @wrap_as_background_process("typing._handle_timeouts") - def _handle_timeouts(self) -> None: + async def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") now = self.clock.time_msec() diff --git a/synapse/http/server.py b/synapse/http/server.py index 1af0d9a31d..91badb0b0a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -98,7 +98,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: "Failed handle request via %r: %r", request.request_metrics.name, request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] ) # Only respond with an error response if we haven't already started writing, @@ -150,7 +150,7 @@ def return_html_error( logger.error( "Failed handle request %r", request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + exc_info=(f.type, f.value, f.getTracebackObject()), # type: 
ignore[arg-type] ) else: code = HTTPStatus.INTERNAL_SERVER_ERROR @@ -159,7 +159,7 @@ def return_html_error( logger.error( "Failed handle request %r", request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] ) if isinstance(error_template, str): diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 91ba93372c..6dd9b9ad03 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -79,6 +79,35 @@ def parse_integer( return parse_integer_from_args(args, name, default, required) +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[int] = None, +) -> Optional[int]: + ... + + +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + *, + required: Literal[True], +) -> int: + ... + + +@overload +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[int] = None, + required: bool = False, +) -> Optional[int]: + ... + + def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py index af5fc407a8..478b527494 100644 --- a/synapse/logging/handlers.py +++ b/synapse/logging/handlers.py @@ -3,7 +3,7 @@ import time from logging import Handler, LogRecord from logging.handlers import MemoryHandler from threading import Thread -from typing import Optional +from typing import Optional, cast from twisted.internet.interfaces import IReactorCore @@ -56,7 +56,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): if reactor is None: from twisted.internet import reactor as global_reactor - reactor_to_use = global_reactor # type: ignore[assignment] + reactor_to_use = cast(IReactorCore, global_reactor) else: reactor_to_use = reactor diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 91ee5c8193..ceef57ad88 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -20,10 +20,25 @@ import os import platform import threading import time -from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) import attr -from prometheus_client import Counter, Gauge, Histogram +from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric from prometheus_client.core import ( REGISTRY, CounterMetricFamily, @@ -32,6 +47,7 @@ from prometheus_client.core import ( ) from twisted.internet import reactor +from twisted.internet.base import ReactorBase from twisted.python.threadpool import ThreadPool import synapse @@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") class RegistryProxy: @staticmethod - def collect(): + def collect() -> Iterable[Metric]: for metric in REGISTRY.collect(): if not metric.name.startswith("__"): yield metric @@ -74,7 +90,7 @@ class LaterGauge: ] ) - def collect(self): + def collect(self) -> Iterable[Metric]: g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) @@ -93,10 +109,10 @@ class LaterGauge: yield g - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._register() - def _register(self): + def _register(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) 
REGISTRY.unregister(all_gauges.pop(self.name)) @@ -105,7 +121,12 @@ class LaterGauge: all_gauges[self.name] = self -class InFlightGauge: +# `MetricsEntry` only makes sense when it is a `Protocol`, +# but `Protocol` can't be used as a `TypeVar` bound. +MetricsEntry = TypeVar("MetricsEntry") + + +class InFlightGauge(Generic[MetricsEntry]): """Tracks number of things (e.g. requests, Measure blocks, etc) in flight at any given time. @@ -115,14 +136,19 @@ class InFlightGauge: callbacks. Args: - name (str) - desc (str) - labels (list[str]) - sub_metrics (list[str]): A list of sub metrics that the callbacks - will update. + name + desc + labels + sub_metrics: A list of sub metrics that the callbacks will update. """ - def __init__(self, name, desc, labels, sub_metrics): + def __init__( + self, + name: str, + desc: str, + labels: Sequence[str], + sub_metrics: Sequence[str], + ): self.name = name self.desc = desc self.labels = labels @@ -130,19 +156,25 @@ class InFlightGauge: # Create a class which have the sub_metrics values as attributes, which # default to 0 on initialization. Used to pass to registered callbacks. - self._metrics_class = attr.make_class( + self._metrics_class: Type[MetricsEntry] = attr.make_class( "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True ) # Counts number of in flight blocks for a given set of label values - self._registrations: Dict = {} + self._registrations: Dict[ + Tuple[str, ...], Set[Callable[[MetricsEntry], None]] + ] = {} # Protects access to _registrations self._lock = threading.Lock() self._register_with_collector() - def register(self, key, callback): + def register( + self, + key: Tuple[str, ...], + callback: Callable[[MetricsEntry], None], + ) -> None: """Registers that we've entered a new block with labels `key`. `callback` gets called each time the metrics are collected. The same @@ -158,13 +190,17 @@ class InFlightGauge: with self._lock: self._registrations.setdefault(key, set()).add(callback) - def unregister(self, key, callback): + def unregister( + self, + key: Tuple[str, ...], + callback: Callable[[MetricsEntry], None], + ) -> None: """Registers that we've exited a block with labels `key`.""" with self._lock: self._registrations.setdefault(key, set()).discard(callback) - def collect(self): + def collect(self) -> Iterable[Metric]: """Called by prometheus client when it reads metrics. Note: may be called by a separate thread. 
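For context on how the now-generic `InFlightGauge` is driven, here is a usage sketch based only on the signatures visible in this hunk; the gauge name, label and sub-metric are made up for illustration:

    from synapse.metrics import InFlightGauge  # assumed import path

    in_flight = InFlightGauge(
        "synapse_example_in_flight",
        "Example things currently in flight",
        labels=["name"],
        sub_metrics=["real_time_sum"],
    )

    def _update(metrics) -> None:
        # `metrics` is an instance of the attrs-generated _MetricsEntry class,
        # with one attribute per sub-metric, defaulting to 0.
        metrics.real_time_sum += 1.0

    key = ("some_block",)               # one value per label
    in_flight.register(key, _update)    # we entered the block
    # ... do the work being measured ...
    in_flight.unregister(key, _update)  # we exited the block

The callback runs at scrape time under `collect()`, which is why the docstring warns that it may be called from a separate thread.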
@@ -200,7 +236,7 @@ class InFlightGauge: gauge.add_metric(key, getattr(metrics, name)) yield gauge - def _register_with_collector(self): + def _register_with_collector(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) REGISTRY.unregister(all_gauges.pop(self.name)) @@ -230,7 +266,7 @@ class GaugeBucketCollector: name: str, documentation: str, buckets: Iterable[float], - registry=REGISTRY, + registry: CollectorRegistry = REGISTRY, ): """ Args: @@ -257,12 +293,12 @@ class GaugeBucketCollector: registry.register(self) - def collect(self): + def collect(self) -> Iterable[Metric]: # Don't report metrics unless we've already collected some data if self._metric is not None: yield self._metric - def update_data(self, values: Iterable[float]): + def update_data(self, values: Iterable[float]) -> None: """Update the data to be reported by the metric The existing data is cleared, and each measurement in the input is assigned @@ -304,7 +340,7 @@ class GaugeBucketCollector: class CPUMetrics: - def __init__(self): + def __init__(self) -> None: ticks_per_sec = 100 try: # Try and get the system config @@ -314,7 +350,7 @@ class CPUMetrics: self.ticks_per_sec = ticks_per_sec - def collect(self): + def collect(self) -> Iterable[Metric]: if not HAVE_PROC_SELF_STAT: return @@ -364,7 +400,7 @@ gc_time = Histogram( class GCCounts: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) for n, m in enumerate(gc.get_count()): cm.add_metric([str(n)], m) @@ -382,7 +418,7 @@ if not running_on_pypy: class PyPyGCStats: - def collect(self): + def collect(self) -> Iterable[Metric]: # @stats is a pretty-printer object with __str__() returning a nice table, # plus some fields that contain data from that table. 
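Each of these `collect` methods now advertises the `() -> Iterable[Metric]` shape that `prometheus_client` expects of a custom collector. For reference, a stand-alone collector with the same signature looks roughly like this (the metric name is made up):

import time
from typing import Iterable

from prometheus_client import Metric
from prometheus_client.core import REGISTRY, GaugeMetricFamily

_STARTED_AT = time.time()


class UptimeCollector:
    def collect(self) -> Iterable[Metric]:
        # Invoked on every scrape, possibly from a separate thread.
        g = GaugeMetricFamily("example_uptime_seconds", "Seconds since process start")
        g.add_metric([], time.time() - _STARTED_AT)
        yield g


REGISTRY.register(UptimeCollector())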
@@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None: class ReactorLastSeenMetric: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", "Seconds since the Twisted reactor was last seen", @@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0) _last_gc = [0.0, 0.0, 0.0] -def runUntilCurrentTimer(reactor, func): +F = TypeVar("F", bound=Callable[..., Any]) + + +def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F: @functools.wraps(func) - def f(*args, **kwargs): + def f(*args: Any, **kwargs: Any) -> Any: now = reactor.seconds() num_pending = 0 @@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func): return ret - return f + return cast(F, f) try: @@ -677,5 +716,5 @@ __all__ = [ "start_http_server", "LaterGauge", "InFlightGauge", - "BucketCollector", + "GaugeBucketCollector", ] diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index bb9bcb5592..353d0a63b6 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -25,27 +25,25 @@ import math import threading from http.server import BaseHTTPRequestHandler, HTTPServer from socketserver import ThreadingMixIn -from typing import Dict, List +from typing import Any, Dict, List, Type, Union from urllib.parse import parse_qs, urlparse -from prometheus_client import REGISTRY +from prometheus_client import REGISTRY, CollectorRegistry +from prometheus_client.core import Sample from twisted.web.resource import Resource +from twisted.web.server import Request from synapse.util import caches CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" -INF = float("inf") -MINUS_INF = float("-inf") - - -def floatToGoString(d): +def floatToGoString(d: Union[int, float]) -> str: d = float(d) - if d == INF: + if d == math.inf: return "+Inf" - elif d == MINUS_INF: + elif d == -math.inf: return "-Inf" elif math.isnan(d): return "NaN" @@ -60,7 +58,7 @@ def floatToGoString(d): return s -def sample_line(line, name): +def sample_line(line: Sample, name: str) -> str: if line.labels: labelstr = "{{{0}}}".format( ",".join( @@ -82,7 +80,7 @@ def sample_line(line, name): return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) -def generate_latest(registry, emit_help=False): +def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: # Trigger the cache metrics to be rescraped, which updates the common # metrics but do not produce metrics themselves @@ -187,7 +185,7 @@ class MetricsHandler(BaseHTTPRequestHandler): registry = REGISTRY - def do_GET(self): + def do_GET(self) -> None: registry = self.registry params = parse_qs(urlparse(self.path).query) @@ -207,11 +205,11 @@ class MetricsHandler(BaseHTTPRequestHandler): self.end_headers() self.wfile.write(output) - def log_message(self, format, *args): + def log_message(self, format: str, *args: Any) -> None: """Log nothing.""" @classmethod - def factory(cls, registry): + def factory(cls, registry: CollectorRegistry) -> Type: """Returns a dynamic MetricsHandler class tied to the passed registry. 
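`runUntilCurrentTimer` above is typed with the usual signature-preserving decorator idiom: a `TypeVar` bound to `Callable[..., Any]` plus a final `cast`, because mypy cannot infer that the wrapper keeps the wrapped function's signature. In isolation the idiom looks like this (the timing body is a stand-in for the real GC/metrics logic):

import functools
import time
from typing import Any, Callable, TypeVar, cast

F = TypeVar("F", bound=Callable[..., Any])


def timed(func: F) -> F:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.monotonic()
        try:
            return func(*args, **kwargs)
        finally:
            print(f"{func.__name__} took {time.monotonic() - start:.6f}s")

    # `wrapper` is only Callable[..., Any] to mypy; the cast restores the
    # original signature for callers of the decorated function.
    return cast(F, wrapper)


@timed
def add(a: int, b: int) -> int:
    return a + b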
""" @@ -236,7 +234,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): daemon_threads = True -def start_http_server(port, addr="", registry=REGISTRY): +def start_http_server( + port: int, addr: str = "", registry: CollectorRegistry = REGISTRY +) -> None: """Starts an HTTP server for prometheus metrics as a daemon thread""" CustomMetricsHandler = MetricsHandler.factory(registry) httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) @@ -252,10 +252,10 @@ class MetricsResource(Resource): isLeaf = True - def __init__(self, registry=REGISTRY): + def __init__(self, registry: CollectorRegistry = REGISTRY): self.registry = registry - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) response = generate_latest(self.registry) request.setHeader(b"Content-Length", str(len(response))) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 2ab599a334..53c508af91 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -15,19 +15,37 @@ import logging import threading from functools import wraps -from typing import TYPE_CHECKING, Dict, Optional, Set, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + Optional, + Set, + Type, + TypeVar, + Union, + cast, +) +from prometheus_client import Metric from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + ContextResourceUsage, + LoggingContext, + PreserveLoggingContext, +) from synapse.logging.opentracing import ( SynapseTags, noop_context_manager, start_active_span, ) -from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import resource @@ -116,7 +134,7 @@ class _Collector: before they are returned. """ - def collect(self): + def collect(self) -> Iterable[Metric]: global _background_processes_active_since_last_scrape # We swap out the _background_processes set with an empty one so that @@ -144,12 +162,12 @@ REGISTRY.register(_Collector()) class _BackgroundProcess: - def __init__(self, desc, ctx): + def __init__(self, desc: str, ctx: LoggingContext): self.desc = desc self._context = ctx - self._reported_stats = None + self._reported_stats: Optional[ContextResourceUsage] = None - def update_metrics(self): + def update_metrics(self) -> None: """Updates the metrics with values from this process.""" new_stats = self._context.get_resource_usage() if self._reported_stats is None: @@ -169,7 +187,16 @@ class _BackgroundProcess: ) -def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs): +R = TypeVar("R") + + +def run_as_background_process( + desc: str, + func: Callable[..., Awaitable[Optional[R]]], + *args: Any, + bg_start_span: bool = True, + **kwargs: Any, +) -> "defer.Deferred[Optional[R]]": """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -189,11 +216,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar args: positional args for func kwargs: keyword args for func - Returns: Deferred which returns the result of func, but note that it does not - follow the synapse logcontext rules. 
+ Returns: + Deferred which returns the result of func, or `None` if func raises. + Note that the returned Deferred does not follow the synapse logcontext + rules. """ - async def run(): + async def run() -> Optional[R]: with _bg_metrics_lock: count = _background_process_counts.get(desc, 0) _background_process_counts[desc] = count + 1 @@ -210,12 +239,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar else: ctx = noop_context_manager() with ctx: - return await maybe_awaitable(func(*args, **kwargs)) + return await func(*args, **kwargs) except Exception: logger.exception( "Background process '%s' threw an exception", desc, ) + return None finally: _background_process_in_flight_count.labels(desc).dec() @@ -225,19 +255,24 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar return defer.ensureDeferred(run()) -def wrap_as_background_process(desc): +F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]]) + + +def wrap_as_background_process(desc: str) -> Callable[[F], F]: """Decorator that wraps a function that gets called as a background process. - Equivalent of calling the function with `run_as_background_process` + Equivalent to calling the function with `run_as_background_process` """ - def wrap_as_background_process_inner(func): + def wrap_as_background_process_inner(func: F) -> F: @wraps(func) - def wrap_as_background_process_inner_2(*args, **kwargs): + def wrap_as_background_process_inner_2( + *args: Any, **kwargs: Any + ) -> "defer.Deferred[Optional[R]]": return run_as_background_process(desc, func, *args, **kwargs) - return wrap_as_background_process_inner_2 + return cast(F, wrap_as_background_process_inner_2) return wrap_as_background_process_inner @@ -265,7 +300,7 @@ class BackgroundProcessLoggingContext(LoggingContext): super().__init__("%s-%s" % (name, instance_id)) self._proc = _BackgroundProcess(name, self) - def start(self, rusage: "Optional[resource.struct_rusage]"): + def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """Log context has started running (again).""" super().start(rusage) @@ -276,7 +311,12 @@ class BackgroundProcessLoggingContext(LoggingContext): with _bg_metrics_lock: _background_processes_active_since_last_scrape.add(self._proc) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Log context has finished.""" super().__exit__(type, value, traceback) diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py index 29ab6c0229..98ed9c0829 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py @@ -16,14 +16,16 @@ import ctypes import logging import os import re -from typing import Optional +from typing import Iterable, Optional + +from prometheus_client import Metric from synapse.metrics import REGISTRY, GaugeMetricFamily logger = logging.getLogger(__name__) -def _setup_jemalloc_stats(): +def _setup_jemalloc_stats() -> None: """Checks to see if jemalloc is loaded, and hooks up a collector to record statistics exposed by jemalloc. 
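One behavioural consequence of the typing work above: `run_as_background_process` now expects `func` itself to return an awaitable (the old `maybe_awaitable` shim moves out to `ModuleApi.looping_background_call`), and the returned Deferred resolves to `None` when `func` raises. A usage sketch under those assumptions, with an invented task:

from synapse.metrics.background_process_metrics import (
    run_as_background_process,
    wrap_as_background_process,
)


async def _prune_old_entries(before_ts: int) -> int:
    # ... some awaitable work; must now be an async callable ...
    return 0


# Fire-and-forget; resource usage is attributed to "prune_old_entries".
# The Deferred yields the int result, or None if the task raised.
d = run_as_background_process("prune_old_entries", _prune_old_entries, 1234)


# Equivalently, as a decorator:
@wrap_as_background_process("prune_old_entries")
async def prune_old_entries(before_ts: int) -> int:
    return 0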
""" @@ -135,7 +137,7 @@ def _setup_jemalloc_stats(): class JemallocCollector: """Metrics for internal jemalloc stats.""" - def collect(self): + def collect(self) -> Iterable[Metric]: _jemalloc_refresh_stats() g = GaugeMetricFamily( @@ -185,7 +187,7 @@ def _setup_jemalloc_stats(): logger.debug("Added jemalloc stats") -def setup_jemalloc_stats(): +def setup_jemalloc_stats() -> None: """Try to setup jemalloc stats, if jemalloc is loaded.""" try: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6e7f5238fe..662e60bc33 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -24,6 +24,7 @@ from typing import ( List, Optional, Tuple, + TypeVar, Union, ) @@ -31,11 +32,48 @@ import attr import jinja2 from twisted.internet import defer -from twisted.web.resource import IResource +from twisted.web.resource import Resource from synapse.api.errors import SynapseError from synapse.events import EventBase -from synapse.events.presence_router import PresenceRouter +from synapse.events.presence_router import ( + GET_INTERESTED_USERS_CALLBACK, + GET_USERS_FOR_STATES_CALLBACK, + PresenceRouter, +) +from synapse.events.spamcheck import ( + CHECK_EVENT_FOR_SPAM_CALLBACK, + CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK, + CHECK_REGISTRATION_FOR_SPAM_CALLBACK, + CHECK_USERNAME_FOR_SPAM_CALLBACK, + USER_MAY_CREATE_ROOM_ALIAS_CALLBACK, + USER_MAY_CREATE_ROOM_CALLBACK, + USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK, + USER_MAY_INVITE_CALLBACK, + USER_MAY_JOIN_ROOM_CALLBACK, + USER_MAY_PUBLISH_ROOM_CALLBACK, + USER_MAY_SEND_3PID_INVITE_CALLBACK, +) +from synapse.events.third_party_rules import ( + CHECK_EVENT_ALLOWED_CALLBACK, + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK, + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, + ON_CREATE_ROOM_CALLBACK, + ON_NEW_EVENT_CALLBACK, +) +from synapse.handlers.account_validity import ( + IS_USER_EXPIRED_CALLBACK, + ON_LEGACY_ADMIN_REQUEST, + ON_LEGACY_RENEW_CALLBACK, + ON_LEGACY_SEND_MAIL_CALLBACK, + ON_USER_REGISTRATION_CALLBACK, +) +from synapse.handlers.auth import ( + CHECK_3PID_AUTH_CALLBACK, + CHECK_AUTH_CALLBACK, + ON_LOGGED_OUT_CALLBACK, + AuthHandler, +) from synapse.http.client import SimpleHttpClient from synapse.http.server import ( DirectServeHtmlResource, @@ -44,10 +82,19 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.context import ( + defer_to_thread, + make_deferred_yieldable, + run_in_background, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client.login import LoginResponse from synapse.storage import DataStore +from synapse.storage.background_updates import ( + DEFAULT_BATCH_SIZE_CALLBACK, + MIN_BATCH_SIZE_CALLBACK, + ON_UPDATE_CALLBACK, +) from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -61,12 +108,16 @@ from synapse.types import ( create_requester, ) from synapse.util import Clock +from synapse.util.async_helpers import maybe_awaitable from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.app.generic_worker import GenericWorkerSlavedStore from synapse.server import HomeServer + +T = TypeVar("T") + """ This package defines the 'stable' API which can be used by extension modules which are loaded into 
Synapse. @@ -114,7 +165,7 @@ class ModuleApi: can register new users etc if necessary. """ - def __init__(self, hs: "HomeServer", auth_handler): + def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None: self._hs = hs # TODO: Fix this type hint once the types for the data stores have been ironed @@ -156,47 +207,139 @@ class ModuleApi: ################################################################################# # The following methods should only be called during the module's initialisation. - @property - def register_spam_checker_callbacks(self): + def register_spam_checker_callbacks( + self, + check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, + user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, + user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, + user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, + user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, + user_may_create_room_with_invites: Optional[ + USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK + ] = None, + user_may_create_room_alias: Optional[ + USER_MAY_CREATE_ROOM_ALIAS_CALLBACK + ] = None, + user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None, + check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None, + check_registration_for_spam: Optional[ + CHECK_REGISTRATION_FOR_SPAM_CALLBACK + ] = None, + check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, + ) -> None: """Registers callbacks for spam checking capabilities. Added in Synapse v1.37.0. """ - return self._spam_checker.register_callbacks + return self._spam_checker.register_callbacks( + check_event_for_spam=check_event_for_spam, + user_may_join_room=user_may_join_room, + user_may_invite=user_may_invite, + user_may_send_3pid_invite=user_may_send_3pid_invite, + user_may_create_room=user_may_create_room, + user_may_create_room_with_invites=user_may_create_room_with_invites, + user_may_create_room_alias=user_may_create_room_alias, + user_may_publish_room=user_may_publish_room, + check_username_for_spam=check_username_for_spam, + check_registration_for_spam=check_registration_for_spam, + check_media_file_for_spam=check_media_file_for_spam, + ) - @property - def register_account_validity_callbacks(self): + def register_account_validity_callbacks( + self, + is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, + on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, + on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, + on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, + on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, + ) -> None: """Registers callbacks for account validity capabilities. Added in Synapse v1.39.0. 
""" - return self._account_validity_handler.register_account_validity_callbacks + return self._account_validity_handler.register_account_validity_callbacks( + is_user_expired=is_user_expired, + on_user_registration=on_user_registration, + on_legacy_send_mail=on_legacy_send_mail, + on_legacy_renew=on_legacy_renew, + on_legacy_admin_request=on_legacy_admin_request, + ) - @property - def register_third_party_rules_callbacks(self): + def register_third_party_rules_callbacks( + self, + check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, + on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, + check_threepid_can_be_invited: Optional[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = None, + check_visibility_can_be_modified: Optional[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = None, + on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + ) -> None: """Registers callbacks for third party event rules capabilities. Added in Synapse v1.39.0. """ - return self._third_party_event_rules.register_third_party_rules_callbacks + return self._third_party_event_rules.register_third_party_rules_callbacks( + check_event_allowed=check_event_allowed, + on_create_room=on_create_room, + check_threepid_can_be_invited=check_threepid_can_be_invited, + check_visibility_can_be_modified=check_visibility_can_be_modified, + on_new_event=on_new_event, + ) - @property - def register_presence_router_callbacks(self): + def register_presence_router_callbacks( + self, + get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, + get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, + ) -> None: """Registers callbacks for presence router capabilities. Added in Synapse v1.42.0. """ - return self._presence_router.register_presence_router_callbacks + return self._presence_router.register_presence_router_callbacks( + get_users_for_states=get_users_for_states, + get_interested_users=get_interested_users, + ) - @property - def register_password_auth_provider_callbacks(self): + def register_password_auth_provider_callbacks( + self, + check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, + on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, + auth_checkers: Optional[ + Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] + ] = None, + ) -> None: """Registers callbacks for password auth provider capabilities. Added in Synapse v1.46.0. """ - return self._password_auth_provider.register_password_auth_provider_callbacks + return self._password_auth_provider.register_password_auth_provider_callbacks( + check_3pid_auth=check_3pid_auth, + on_logged_out=on_logged_out, + auth_checkers=auth_checkers, + ) + + def register_background_update_controller_callbacks( + self, + on_update: ON_UPDATE_CALLBACK, + default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, + min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None, + ) -> None: + """Registers background update controller callbacks. + + Added in Synapse v1.49.0. + """ + + for db in self._hs.get_datastores().databases: + db.updates.register_update_controller_callbacks( + on_update=on_update, + default_batch_size=default_batch_size, + min_batch_size=min_batch_size, + ) - def register_web_resource(self, path: str, resource: IResource): + def register_web_resource(self, path: str, resource: Resource) -> None: """Registers a web resource to be served at the given path. This function should be called during initialisation of the module. 
@@ -216,7 +359,7 @@ class ModuleApi: # The following methods can be called by the module at any point in time. @property - def http_client(self): + def http_client(self) -> SimpleHttpClient: """Allows making outbound HTTP requests to remote resources. An instance of synapse.http.client.SimpleHttpClient @@ -226,7 +369,7 @@ class ModuleApi: return self._http_client @property - def public_room_list_manager(self): + def public_room_list_manager(self) -> "PublicRoomListManager": """Allows adding to, removing from and checking the status of rooms in the public room list. @@ -309,7 +452,7 @@ class ModuleApi: """ return await self._store.is_server_admin(UserID.from_string(user_id)) - def get_qualified_user_id(self, username): + def get_qualified_user_id(self, username: str) -> str: """Qualify a user id, if necessary Takes a user id provided by the user and adds the @ and :domain to @@ -318,10 +461,10 @@ class ModuleApi: Added in Synapse v0.25.0. Args: - username (str): provided user id + username: provided user id Returns: - str: qualified @user:id + qualified @user:id """ if username.startswith("@"): return username @@ -357,22 +500,27 @@ class ModuleApi: """ return await self._store.user_get_threepids(user_id) - def check_user_exists(self, user_id): + def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]": """Check if user exists. Added in Synapse v0.25.0. Args: - user_id (str): Complete @user:id + user_id: Complete @user:id Returns: - Deferred[str|None]: Canonical (case-corrected) user_id, or None + Canonical (case-corrected) user_id, or None if the user is not registered. """ return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks - def register(self, localpart, displayname=None, emails: Optional[List[str]] = None): + def register( + self, + localpart: str, + displayname: Optional[str] = None, + emails: Optional[List[str]] = None, + ) -> Generator["defer.Deferred[Any]", Any, Tuple[str, str]]: """Registers a new user with given localpart and optional displayname, emails. Also returns an access token for the new user. @@ -384,12 +532,12 @@ class ModuleApi: Added in Synapse v0.25.0. Args: - localpart (str): The localpart of the new user. - displayname (str|None): The displayname of the new user. - emails (List[str]): Emails to bind to the new user. + localpart: The localpart of the new user. + displayname: The displayname of the new user. + emails: Emails to bind to the new user. Returns: - Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token) + a 2-tuple of (user_id, access_token) """ logger.warning( "Using deprecated ModuleApi.register which creates a dummy user device." @@ -399,23 +547,26 @@ class ModuleApi: return user_id, access_token def register_user( - self, localpart, displayname=None, emails: Optional[List[str]] = None - ): + self, + localpart: str, + displayname: Optional[str] = None, + emails: Optional[List[str]] = None, + ) -> "defer.Deferred[str]": """Registers a new user with given localpart and optional displayname, emails. Added in Synapse v1.2.0. Args: - localpart (str): The localpart of the new user. - displayname (str|None): The displayname of the new user. - emails (List[str]): Emails to bind to the new user. + localpart: The localpart of the new user. + displayname: The displayname of the new user. + emails: Emails to bind to the new user. Raises: SynapseError if there is an error performing the registration. 
Check the 'errcode' property for more information on the reason for failure Returns: - defer.Deferred[str]: user_id + user_id """ return defer.ensureDeferred( self._hs.get_registration_handler().register_user( @@ -425,20 +576,25 @@ class ModuleApi: ) ) - def register_device(self, user_id, device_id=None, initial_display_name=None): + def register_device( + self, + user_id: str, + device_id: Optional[str] = None, + initial_display_name: Optional[str] = None, + ) -> "defer.Deferred[Tuple[str, str, Optional[int], Optional[str]]]": """Register a device for a user and generate an access token. Added in Synapse v1.2.0. Args: - user_id (str): full canonical @user:id - device_id (str|None): The device ID to check, or None to generate + user_id: full canonical @user:id + device_id: The device ID to check, or None to generate a new one. - initial_display_name (str|None): An optional display name for the + initial_display_name: An optional display name for the device. Returns: - defer.Deferred[tuple[str, str]]: Tuple of device ID and access token + Tuple of device ID, access token, access token expiration time and refresh token """ return defer.ensureDeferred( self._hs.get_registration_handler().register_device( @@ -471,6 +627,7 @@ class ModuleApi: user_id: str, duration_in_ms: int = (2 * 60 * 1000), auth_provider_id: str = "", + auth_provider_session_id: Optional[str] = None, ) -> str: """Generate a login token suitable for m.login.token authentication @@ -488,11 +645,14 @@ class ModuleApi: return self._hs.get_macaroon_generator().generate_short_term_login_token( user_id, auth_provider_id, + auth_provider_session_id, duration_in_ms, ) @defer.inlineCallbacks - def invalidate_access_token(self, access_token): + def invalidate_access_token( + self, access_token: str + ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user Added in Synapse v0.25.0. @@ -524,14 +684,20 @@ class ModuleApi: self._auth_handler.delete_access_token(access_token) ) - def run_db_interaction(self, desc, func, *args, **kwargs): + def run_db_interaction( + self, + desc: str, + func: Callable[..., T], + *args: Any, + **kwargs: Any, + ) -> "defer.Deferred[T]": """Run a function with a database connection Added in Synapse v0.25.0. Args: - desc (str): description for the transaction, for metrics etc - func (func): function to be run. Passed a database cursor object + desc: description for the transaction, for metrics etc + func: function to be run. Passed a database cursor object as well as *args and **kwargs *args: positional args to be passed to func **kwargs: named args to be passed to func @@ -545,7 +711,7 @@ class ModuleApi: def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str - ): + ) -> None: """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. @@ -575,7 +741,7 @@ class ModuleApi: client_redirect_url: str, new_user: bool = False, auth_provider_id: str = "<unknown>", - ): + ) -> None: """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. 
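Several of the ModuleApi methods annotated above (`check_user_exists`, `register_user`, `register_device`) still return `Deferred`s for backwards compatibility with old-style modules. Twisted `Deferred`s are awaitable, so new async code can consume them directly; a sketch (the user IDs are examples):

async def ensure_user(api) -> str:
    # check_user_exists returns Deferred[Optional[str]]: awaiting it yields
    # the canonical (case-corrected) user ID, or None if unregistered.
    user_id = await api.check_user_exists("@alice:example.com")
    if user_id is None:
        user_id = await api.register_user("alice", displayname="Alice")
    return user_id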
@@ -814,11 +980,11 @@ class ModuleApi: self, f: Callable, msec: float, - *args, + *args: object, desc: Optional[str] = None, run_on_all_instances: bool = False, - **kwargs, - ): + **kwargs: object, + ) -> None: """Wraps a function as a background process and calls it repeatedly. NOTE: Will only run on the instance that is configured to run @@ -849,9 +1015,7 @@ class ModuleApi: run_as_background_process, msec, desc, - f, - *args, - **kwargs, + lambda: maybe_awaitable(f(*args, **kwargs)), ) else: logger.warning( @@ -859,13 +1023,18 @@ class ModuleApi: f, ) + async def sleep(self, seconds: float) -> None: + """Sleeps for the given number of seconds.""" + + await self._clock.sleep(seconds) + async def send_mail( self, recipient: str, subject: str, html: str, text: str, - ): + ) -> None: """Send an email on behalf of the homeserver. Added in Synapse v1.39.0. @@ -903,7 +1072,7 @@ class ModuleApi: A list containing the loaded templates, with the orders matching the one of the filenames parameter. """ - return self._hs.config.read_templates( + return self._hs.config.server.read_templates( filenames, (td for td in (self.custom_template_dir, custom_template_directory) if td), ) @@ -1013,6 +1182,26 @@ class ModuleApi: return {key: state_events[event_id] for key, event_id in state_ids.items()} + async def defer_to_thread( + self, + f: Callable[..., T], + *args: Any, + **kwargs: Any, + ) -> T: + """Runs the given function in a separate thread from Synapse's thread pool. + + Added in Synapse v1.49.0. + + Args: + f: The function to run. + args: The function's arguments. + kwargs: The function's keyword arguments. + + Returns: + The return value of the function once ran in a thread. + """ + return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index cf5abdfbda..4f13c0418a 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -21,6 +21,8 @@ from twisted.internet.interfaces import IDelayedCall from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams from synapse.push.mailer import Mailer +from synapse.push.push_types import EmailReason +from synapse.storage.databases.main.event_push_actions import EmailPushAction from synapse.util.threepids import validate_email if TYPE_CHECKING: @@ -190,7 +192,7 @@ class EmailPusher(Pusher): # we then consider all previously outstanding notifications # to be delivered. - reason = { + reason: EmailReason = { "room_id": push_action["room_id"], "now": self.clock.time_msec(), "received_at": received_at, @@ -275,7 +277,7 @@ class EmailPusher(Pusher): return may_send_at async def sent_notif_update_throttle( - self, room_id: str, notified_push_action: dict + self, room_id: str, notified_push_action: EmailPushAction ) -> None: # We have sent a notification, so update the throttle accordingly. 
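The new `ModuleApi.defer_to_thread` gives modules a supported way to push blocking or CPU-bound work onto Synapse's thread pool instead of stalling the reactor. A sketch of the intended use (the digest workload is just an example):

import hashlib


def _expensive_digest(data: bytes) -> str:
    # Blocking/CPU-bound work that must not run on the reactor thread.
    return hashlib.sha256(data).hexdigest()


async def handle_payload(api, payload: bytes) -> str:
    # Runs in a worker thread; the coroutine resumes with the result.
    return await api.defer_to_thread(_expensive_digest, payload)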
# If the event that triggered the notif happened more than @@ -315,7 +317,9 @@ class EmailPusher(Pusher): self.pusher_id, room_id, self.throttle_params[room_id] ) - async def send_notification(self, push_actions: List[dict], reason: dict) -> None: + async def send_notification( + self, push_actions: List[EmailPushAction], reason: EmailReason + ) -> None: logger.info("Sending notif email for user %r", self.user_id) await self.mailer.send_notification_mail( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index dbf4ad7f97..3fa603ccb7 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -26,6 +26,7 @@ from synapse.events import EventBase from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException +from synapse.storage.databases.main.event_push_actions import HttpPushAction from . import push_rule_evaluator, push_tools @@ -273,7 +274,7 @@ class HttpPusher(Pusher): ) break - async def _process_one(self, push_action: dict) -> bool: + async def _process_one(self, push_action: HttpPushAction) -> bool: if "notify" not in push_action["actions"]: return True diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index ce299ba3da..ba4f866487 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -14,7 +14,7 @@ import logging import urllib.parse -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar import bleach import jinja2 @@ -28,6 +28,14 @@ from synapse.push.presentable_names import ( descriptor_from_member_events, name_from_member_event, ) +from synapse.push.push_types import ( + EmailReason, + MessageVars, + NotifVars, + RoomVars, + TemplateVars, +) +from synapse.storage.databases.main.event_push_actions import EmailPushAction from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID from synapse.util.async_helpers import concurrently_execute @@ -135,7 +143,7 @@ class Mailer: % urllib.parse.urlencode(params) ) - template_vars = {"link": link} + template_vars: TemplateVars = {"link": link} await self.send_email( email_address, @@ -165,7 +173,7 @@ class Mailer: % urllib.parse.urlencode(params) ) - template_vars = {"link": link} + template_vars: TemplateVars = {"link": link} await self.send_email( email_address, @@ -196,7 +204,7 @@ class Mailer: % urllib.parse.urlencode(params) ) - template_vars = {"link": link} + template_vars: TemplateVars = {"link": link} await self.send_email( email_address, @@ -210,8 +218,8 @@ class Mailer: app_id: str, user_id: str, email_address: str, - push_actions: Iterable[Dict[str, Any]], - reason: Dict[str, Any], + push_actions: Iterable[EmailPushAction], + reason: EmailReason, ) -> None: """ Send email regarding a user's room notifications @@ -230,7 +238,7 @@ class Mailer: [pa["event_id"] for pa in push_actions] ) - notifs_by_room: Dict[str, List[Dict[str, Any]]] = {} + notifs_by_room: Dict[str, List[EmailPushAction]] = {} for pa in push_actions: notifs_by_room.setdefault(pa["room_id"], []).append(pa) @@ -258,7 +266,7 @@ class Mailer: # actually sort our so-called rooms_in_order list, most recent room first rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) - rooms: List[Dict[str, Any]] = [] + rooms: List[RoomVars] = [] for r in rooms_in_order: roomvars = await self._get_room_vars( @@ -289,7 +297,7 @@ class Mailer: 
notifs_by_room, state_by_room, notif_events, reason ) - template_vars = { + template_vars: TemplateVars = { "user_display_name": user_display_name, "unsubscribe_link": self._make_unsubscribe_link( user_id, app_id, email_address @@ -302,10 +310,10 @@ class Mailer: await self.send_email(email_address, summary_text, template_vars) async def send_email( - self, email_address: str, subject: str, extra_template_vars: Dict[str, Any] + self, email_address: str, subject: str, extra_template_vars: TemplateVars ) -> None: """Send an email with the given information and template text""" - template_vars = { + template_vars: TemplateVars = { "app_name": self.app_name, "server_name": self.hs.config.server.server_name, } @@ -327,10 +335,10 @@ class Mailer: self, room_id: str, user_id: str, - notifs: Iterable[Dict[str, Any]], + notifs: Iterable[EmailPushAction], notif_events: Dict[str, EventBase], room_state_ids: StateMap[str], - ) -> Dict[str, Any]: + ) -> RoomVars: """ Generate the variables for notifications on a per-room basis. @@ -356,7 +364,7 @@ class Mailer: room_name = await calculate_room_name(self.store, room_state_ids, user_id) - room_vars: Dict[str, Any] = { + room_vars: RoomVars = { "title": room_name, "hash": string_ordinal_total(room_id), # See sender avatar hash "notifs": [], @@ -417,11 +425,11 @@ class Mailer: async def _get_notif_vars( self, - notif: Dict[str, Any], + notif: EmailPushAction, user_id: str, notif_event: EventBase, room_state_ids: StateMap[str], - ) -> Dict[str, Any]: + ) -> NotifVars: """ Generate the variables for a single notification. @@ -442,7 +450,7 @@ class Mailer: after_limit=CONTEXT_AFTER, ) - ret = { + ret: NotifVars = { "link": self._make_notif_link(notif), "ts": notif["received_ts"], "messages": [], @@ -461,8 +469,8 @@ class Mailer: return ret async def _get_message_vars( - self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str] - ) -> Optional[Dict[str, Any]]: + self, notif: EmailPushAction, event: EventBase, room_state_ids: StateMap[str] + ) -> Optional[MessageVars]: """ Generate the variables for a single event, if possible. @@ -494,7 +502,9 @@ class Mailer: if sender_state_event: sender_name = name_from_member_event(sender_state_event) - sender_avatar_url = sender_state_event.content.get("avatar_url") + sender_avatar_url: Optional[str] = sender_state_event.content.get( + "avatar_url" + ) else: # No state could be found, fallback to the MXID. sender_name = event.sender @@ -504,7 +514,7 @@ class Mailer: # sender_hash % the number of default images to choose from sender_hash = string_ordinal_total(event.sender) - ret = { + ret: MessageVars = { "event_type": event.type, "is_historical": event.event_id != notif["event_id"], "id": event.event_id, @@ -519,6 +529,8 @@ class Mailer: return ret msgtype = event.content.get("msgtype") + if not isinstance(msgtype, str): + msgtype = None ret["msgtype"] = msgtype @@ -533,7 +545,7 @@ class Mailer: return ret def _add_text_message_vars( - self, messagevars: Dict[str, Any], event: EventBase + self, messagevars: MessageVars, event: EventBase ) -> None: """ Potentially add a sanitised message body to the message variables. @@ -543,8 +555,8 @@ class Mailer: event: The event under consideration. 
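These mailer annotations lean on `TypedDict` with `total=False`, which lets the code keep building the template dicts key by key while mypy checks every key name and value type. The behaviour in miniature (an invented, cut-down dict type):

from typing import Optional

from typing_extensions import TypedDict


class Message(TypedDict, total=False):
    event_type: str
    sender_name: str
    msgtype: Optional[str]


msg: Message = {"event_type": "m.room.message"}  # any subset of keys is fine
msg["sender_name"] = "Alice"   # OK: value type checked
# msg["sender_name"] = 1       # rejected by mypy: int is not str
# msg["unknown"] = "x"         # rejected by mypy: not a declared key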
""" msgformat = event.content.get("format") - - messagevars["format"] = msgformat + if not isinstance(msgformat, str): + msgformat = None formatted_body = event.content.get("formatted_body") body = event.content.get("body") @@ -555,7 +567,7 @@ class Mailer: messagevars["body_text_html"] = safe_text(body) def _add_image_message_vars( - self, messagevars: Dict[str, Any], event: EventBase + self, messagevars: MessageVars, event: EventBase ) -> None: """ Potentially add an image URL to the message variables. @@ -570,7 +582,7 @@ class Mailer: async def _make_summary_text_single_room( self, room_id: str, - notifs: List[Dict[str, Any]], + notifs: List[EmailPushAction], room_state_ids: StateMap[str], notif_events: Dict[str, EventBase], user_id: str, @@ -685,10 +697,10 @@ class Mailer: async def _make_summary_text( self, - notifs_by_room: Dict[str, List[Dict[str, Any]]], + notifs_by_room: Dict[str, List[EmailPushAction]], room_state_ids: Dict[str, StateMap[str]], notif_events: Dict[str, EventBase], - reason: Dict[str, Any], + reason: EmailReason, ) -> str: """ Make a summary text for the email when multiple rooms have notifications. @@ -718,7 +730,7 @@ class Mailer: async def _make_summary_text_from_member_events( self, room_id: str, - notifs: List[Dict[str, Any]], + notifs: List[EmailPushAction], room_state_ids: StateMap[str], notif_events: Dict[str, EventBase], ) -> str: @@ -805,7 +817,7 @@ class Mailer: base_url = "https://matrix.to/#" return "%s/%s" % (base_url, room_id) - def _make_notif_link(self, notif: Dict[str, str]) -> str: + def _make_notif_link(self, notif: EmailPushAction) -> str: """ Generate a link to open an event in the web client. diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 9c85200c0f..da641aca47 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Dict +from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage @@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ) badge = len(invites) diff --git a/synapse/push/push_types.py b/synapse/push/push_types.py new file mode 100644 index 0000000000..8d16ab62ce --- /dev/null +++ b/synapse/push/push_types.py @@ -0,0 +1,136 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List, Optional + +from typing_extensions import TypedDict + + +class EmailReason(TypedDict, total=False): + """ + Information on the event that triggered the email to be sent + + room_id: the ID of the room the event was sent in + now: timestamp in ms when the email is being sent out + room_name: a human-readable name for the room the event was sent in + received_at: the time in milliseconds at which the event was received + delay_before_mail_ms: the amount of time in milliseconds Synapse always waits + before ever emailing about a notification (to give the user a chance to respond + to other push or notice the window) + last_sent_ts: the time in milliseconds at which a notification was last sent + for an event in this room + throttle_ms: the minimum amount of time in milliseconds between two + notifications being sent for this room + """ + + room_id: str + now: int + room_name: Optional[str] + received_at: int + delay_before_mail_ms: int + last_sent_ts: int + throttle_ms: int + + +class MessageVars(TypedDict, total=False): + """ + Details about a specific message to include in a notification + + event_type: the type of the event + is_historical: a boolean, which is `False` if the message is the one + that triggered the notification, `True` otherwise + id: the ID of the event + ts: the time in milliseconds at which the event was sent + sender_name: the display name for the event's sender + sender_avatar_url: the avatar URL (as a `mxc://` URL) for the event's + sender + sender_hash: a hash of the user ID of the sender + msgtype: the type of the message + body_text_html: html representation of the message + body_text_plain: plaintext representation of the message + image_url: mxc url of an image, when "msgtype" is "m.image" + """ + + event_type: str + is_historical: bool + id: str + ts: int + sender_name: str + sender_avatar_url: Optional[str] + sender_hash: int + msgtype: Optional[str] + body_text_html: str + body_text_plain: str + image_url: str + + +class NotifVars(TypedDict): + """ + Details about an event we are about to include in a notification + + link: a `matrix.to` link to the event + ts: the time in milliseconds at which the event was received + messages: a list of messages containing one message before the event, the + message in the event, and one message after the event. + """ + + link: str + ts: Optional[int] + messages: List[MessageVars] + + +class RoomVars(TypedDict): + """ + Represents a room containing events to include in the email. + + title: a human-readable name for the room + hash: a hash of the ID of the room + invite: a boolean, which is `True` if the room is an invite the user hasn't + accepted yet, `False` otherwise + notifs: a list of events, or an empty list if `invite` is `True`. + link: a `matrix.to` link to the room + avatar_url: url to the room's avatar + """ + + title: Optional[str] + hash: int + invite: bool + notifs: List[NotifVars] + link: str + avatar_url: Optional[str] + + +class TemplateVars(TypedDict, total=False): + """ + Generic structure for passing to the email sender; can hold all the fields used in email templates. + + app_name: name of the app/service this homeserver is associated with + server_name: name of our own homeserver + link: a link to include into the email to be sent + user_display_name: the display name for the user receiving the notification + unsubscribe_link: the link users can click to unsubscribe from email notifications + summary_text: a summary of the notification(s).
The text used can be customised + by configuring the various settings in the `email.subjects` section of the + configuration file. + rooms: a list of rooms containing events to include in the email + reason: information on the event that triggered the email to be sent + """ + + app_name: str + server_name: str + link: str + user_display_name: str + unsubscribe_link: str + summary_text: str + rooms: List[RoomVars] + reason: EmailReason diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 154e5b7028..7d26954244 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -86,7 +86,7 @@ REQUIREMENTS = [ # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", - "ijson>=3.0", + "ijson>=3.1", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 0db419ea57..daacc34cea 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost, should_issue_refresh_token, + auth_provider_id, + auth_provider_session_id, ): """ Args: @@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, "is_appservice_ghost": is_appservice_ghost, "should_issue_refresh_token": should_issue_refresh_token, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, } async def _handle_request(self, request, user_id): @@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] should_issue_refresh_token = content["should_issue_refresh_token"] + auth_provider_id = content["auth_provider_id"] + auth_provider_session_id = content["auth_provider_session_id"] res = await self.registration_handler.register_device_inner( user_id, @@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) return 200, res diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 8c1bf9227a..fa132d10b4 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -14,10 +14,18 @@ from typing import List, Optional, Tuple from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import _load_current_id +from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id -class SlavedIdTracker: +class SlavedIdTracker(AbstractStreamIdTracker): + """Tracks the "current" stream ID of a stream with a single writer. + + See `AbstractStreamIdTracker` for more details. + + Note that this class does not work correctly when there are multiple + writers. + """ + def __init__( self, db_conn: LoggingDatabaseConnection, @@ -36,17 +44,7 @@ class SlavedIdTracker: self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self) -> int: - """ - - Returns: - int - """ return self._current def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. 
- - For streams with single writers this is equivalent to - `get_current_token`. - """ return self.get_current_token() diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 4d5f862862..7541e21de9 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.databases.main.push_rule import PushRulesWorkerStore @@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - # We assert this for the benefit of mypy - assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) - if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 55326877fd..a9d85f4f6c 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING from prometheus_client import Counter -from twisted.internet.protocol import Factory +from twisted.internet.protocol import ServerFactory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import PositionCommand @@ -38,7 +38,7 @@ stream_updates_counter = Counter( logger = logging.getLogger(__name__) -class ReplicationStreamProtocolFactory(Factory): +class ReplicationStreamProtocolFactory(ServerFactory): """Factory for new replication connections.""" def __init__(self, hs: "HomeServer"): diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index a030e9299e..a390cfcb74 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -14,7 +14,7 @@ # limitations under the License. import heapq from collections.abc import Iterable -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Optional, Tuple, Type import attr @@ -157,7 +157,7 @@ class EventsStream(Stream): # now we fetch up to that many rows from the events table - event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows( + event_rows = await self._store.get_all_new_forward_event_rows( instance_name, from_token, current_token, target_row_count ) @@ -191,7 +191,7 @@ class EventsStream(Stream): # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. - ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows( + ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( instance_name, from_token, upper_limit ) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index e04af705eb..cebdeecb81 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from synapse.http.server import HttpServer, JsonResource from synapse.rest import admin @@ -62,6 +62,8 @@ from synapse.rest.client import ( if TYPE_CHECKING: from synapse.server import HomeServer +RegisterServletsFunc = Callable[["HomeServer", HttpServer], None] + class ClientRestResource(JsonResource): """Matrix Client API REST resource. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 81e98f81d6..701c609c12 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -17,6 +17,7 @@ import logging import platform +from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple import synapse @@ -28,6 +29,7 @@ from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin.background_updates import ( BackgroundUpdateEnabledRestServlet, BackgroundUpdateRestServlet, + BackgroundUpdateStartJobRestServlet, ) from synapse.rest.admin.devices import ( DeleteDevicesRestServlet, @@ -38,6 +40,10 @@ from synapse.rest.admin.event_reports import ( EventReportDetailRestServlet, EventReportsRestServlet, ) +from synapse.rest.admin.federation import ( + DestinationsRestServlet, + ListDestinationsRestServlet, +) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.registration_tokens import ( @@ -46,6 +52,9 @@ from synapse.rest.admin.registration_tokens import ( RegistrationTokenRestServlet, ) from synapse.rest.admin.rooms import ( + BlockRoomRestServlet, + DeleteRoomStatusByDeleteIdRestServlet, + DeleteRoomStatusByRoomIdRestServlet, ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, @@ -53,6 +62,7 @@ from synapse.rest.admin.rooms import ( RoomEventContextServlet, RoomMembersRestServlet, RoomRestServlet, + RoomRestV2Servlet, RoomStateRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -93,12 +103,12 @@ class VersionServlet(RestServlet): } def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - return 200, self.res + return HTTPStatus.OK, self.res class PurgeHistoryRestServlet(RestServlet): PATTERNS = admin_patterns( - "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" 
+ "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$" ) def __init__(self, hs: "HomeServer"): @@ -125,7 +135,7 @@ class PurgeHistoryRestServlet(RestServlet): event = await self.store.get_event(event_id) if event.room_id != room_id: - raise SynapseError(400, "Event is for wrong room.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.") # RoomStreamToken expects [int] not Optional[int] assert event.internal_metadata.stream_ordering is not None @@ -139,7 +149,9 @@ class PurgeHistoryRestServlet(RestServlet): ts = body["purge_up_to_ts"] if not isinstance(ts, int): raise SynapseError( - 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "purge_up_to_ts must be an int", + errcode=Codes.BAD_JSON, ) stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) @@ -155,7 +167,9 @@ class PurgeHistoryRestServlet(RestServlet): stream_ordering, ) raise SynapseError( - 404, "there is no event to be purged", errcode=Codes.NOT_FOUND + HTTPStatus.NOT_FOUND, + "there is no event to be purged", + errcode=Codes.NOT_FOUND, ) (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) @@ -168,7 +182,7 @@ class PurgeHistoryRestServlet(RestServlet): ) else: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "must specify purge_up_to_event_id or purge_up_to_ts", errcode=Codes.BAD_JSON, ) @@ -177,11 +191,11 @@ class PurgeHistoryRestServlet(RestServlet): room_id, token, delete_local_events=delete_local_events ) - return 200, {"purge_id": purge_id} + return HTTPStatus.OK, {"purge_id": purge_id} class PurgeHistoryStatusRestServlet(RestServlet): - PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)") + PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() @@ -196,7 +210,7 @@ class PurgeHistoryStatusRestServlet(RestServlet): if purge_status is None: raise NotFoundError("purge id '%s' not found" % purge_id) - return 200, purge_status.asdict() + return HTTPStatus.OK, purge_status.asdict() ######################################################################################## @@ -220,10 +234,14 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: Register all the admin servlets. """ register_servlets_for_client_rest_resource(hs, http_server) + BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) + RoomRestV2Servlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) + DeleteRoomStatusByDeleteIdRestServlet(hs).register(http_server) + DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) @@ -247,12 +265,15 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRegistrationTokensRestServlet(hs).register(http_server) NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) + DestinationsRestServlet(hs).register(http_server) + ListDestinationsRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. 
if hs.config.worker.worker_app is None: SendServerNoticeServlet(hs).register(http_server) BackgroundUpdateEnabledRestServlet(hs).register(http_server) BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index d9a2f6ca15..399b205aaf 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -13,6 +13,7 @@ # limitations under the License. import re +from http import HTTPStatus from typing import Iterable, Pattern from synapse.api.auth import Auth @@ -62,4 +63,4 @@ async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: """ is_admin = await auth.is_server_admin(user_id) if not is_admin: - raise AuthError(403, "You are not a server admin") + raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index 0d0183bf20..6ec00ce0b9 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) from synapse.http.site import SynapseRequest -from synapse.rest.admin._base import admin_patterns, assert_user_is_admin +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import JsonDict if TYPE_CHECKING: @@ -29,37 +34,34 @@ logger = logging.getLogger(__name__) class BackgroundUpdateEnabledRestServlet(RestServlet): """Allows temporarily disabling background updates""" - PATTERNS = admin_patterns("/background_updates/enabled") + PATTERNS = admin_patterns("/background_updates/enabled$") def __init__(self, hs: "HomeServer"): - self.group_server = hs.get_groups_server_handler() - self.is_mine_id = hs.is_mine_id - self.auth = hs.get_auth() - - self.data_stores = hs.get_datastores() + self._auth = hs.get_auth() + self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) 
- enabled = all(db.updates.enabled for db in self.data_stores.databases) + enabled = all(db.updates.enabled for db in self._data_stores.databases) - return 200, {"enabled": enabled} + return HTTPStatus.OK, {"enabled": enabled} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) enabled = body.get("enabled", True) if not isinstance(enabled, bool): - raise SynapseError(400, "'enabled' parameter must be a boolean") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'enabled' parameter must be a boolean" + ) - for db in self.data_stores.databases: + for db in self._data_stores.databases: db.updates.enabled = enabled # If we're re-enabling them ensure that we start the background @@ -67,32 +69,28 @@ class BackgroundUpdateEnabledRestServlet(RestServlet): if enabled: db.updates.start_doing_background_updates() - return 200, {"enabled": enabled} + return HTTPStatus.OK, {"enabled": enabled} class BackgroundUpdateRestServlet(RestServlet): """Fetch information about background updates""" - PATTERNS = admin_patterns("/background_updates/status") + PATTERNS = admin_patterns("/background_updates/status$") def __init__(self, hs: "HomeServer"): - self.group_server = hs.get_groups_server_handler() - self.is_mine_id = hs.is_mine_id - self.auth = hs.get_auth() - - self.data_stores = hs.get_datastores() + self._auth = hs.get_auth() + self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) 
- enabled = all(db.updates.enabled for db in self.data_stores.databases) + enabled = all(db.updates.enabled for db in self._data_stores.databases) current_updates = {} - for db in self.data_stores.databases: + for db in self._data_stores.databases: update = db.updates.get_current_update() if not update: continue @@ -104,4 +102,71 @@ class BackgroundUpdateRestServlet(RestServlet): "average_items_per_ms": update.average_items_per_ms(), } - return 200, {"enabled": enabled, "current_updates": current_updates} + return HTTPStatus.OK, {"enabled": enabled, "current_updates": current_updates} + + +class BackgroundUpdateStartJobRestServlet(RestServlet): + """Allows to start specific background updates""" + + PATTERNS = admin_patterns("/background_updates/start_job$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["job_name"]) + + job_name = body["job_name"] + + if job_name == "populate_stats_process_rooms": + jobs = [ + { + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + }, + ] + elif job_name == "regenerate_directory": + jobs = [ + { + "update_name": "populate_user_directory_createtables", + "progress_json": "{}", + "depends_on": "", + }, + { + "update_name": "populate_user_directory_process_rooms", + "progress_json": "{}", + "depends_on": "populate_user_directory_createtables", + }, + { + "update_name": "populate_user_directory_process_users", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_rooms", + }, + { + "update_name": "populate_user_directory_cleanup", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_users", + }, + ] + else: + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name") + + try: + await self._store.db_pool.simple_insert_many( + table="background_updates", + values=jobs, + desc=f"admin_api_run_{job_name}", + ) + except self._store.db_pool.engine.module.IntegrityError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Job %s is already in queue of background updates." % (job_name,), + ) + + self._store.db_pool.updates.start_doing_background_updates() + + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 80fbf32f17..062a33d28d 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
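A minimal sketch of driving the new `/background_updates/start_job` endpoint added above, assuming a homeserver at localhost:8008 and an admin access token; the `requests` client and all concrete values are illustrative assumptions, not part of this commit:

    import requests  # assumed third-party HTTP client

    BASE = "http://localhost:8008"  # hypothetical homeserver URL
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # hypothetical token

    # job_name must be one of the names the servlet accepts above:
    # "populate_stats_process_rooms" or "regenerate_directory".
    resp = requests.post(
        f"{BASE}/_synapse/admin/v1/background_updates/start_job",
        headers=HEADERS,
        json={"job_name": "regenerate_directory"},
    )
    resp.raise_for_status()  # the servlet replies HTTPStatus.OK with an empty JSON object

Submitting a job that is already queued trips the IntegrityError branch above and is rejected with HTTPStatus.BAD_REQUEST.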
import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError @@ -41,10 +42,10 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str, device_id: str @@ -52,8 +53,8 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -62,7 +63,7 @@ class DeviceRestServlet(RestServlet): device = await self.device_handler.get_device( target_user.to_string(), device_id ) - return 200, device + return HTTPStatus.OK, device async def on_DELETE( self, request: SynapseRequest, user_id: str, device_id: str @@ -70,15 +71,15 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") await self.device_handler.delete_device(target_user.to_string(), device_id) - return 200, {} + return HTTPStatus.OK, {} async def on_PUT( self, request: SynapseRequest, user_id: str, device_id: str @@ -86,8 +87,8 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -97,7 +98,7 @@ class DeviceRestServlet(RestServlet): await self.device_handler.update_device( target_user.to_string(), device_id, body ) - return 200, {} + return HTTPStatus.OK, {} class DevicesRestServlet(RestServlet): @@ -108,14 +109,10 @@ class DevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") def __init__(self, hs: "HomeServer"): - """ - Args: - hs: server - """ - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -123,15 +120,15 @@ class DevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: raise NotFoundError("Unknown user") devices = await self.device_handler.get_devices_by_user(target_user.to_string()) - return 200, {"devices": devices, "total": len(devices)} + return 
HTTPStatus.OK, {"devices": devices, "total": len(devices)} class DeleteDevicesRestServlet(RestServlet): @@ -143,10 +140,10 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, user_id: str @@ -154,8 +151,8 @@ class DeleteDevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) if u is None: @@ -167,4 +164,4 @@ class DeleteDevicesRestServlet(RestServlet): await self.device_handler.delete_devices( target_user.to_string(), body["devices"] ) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index bbfcaf723b..38477f8ead 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -51,7 +52,6 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -66,21 +66,23 @@ class EventReportsRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "The start parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "The limit parameter must be a positive integer.", errcode=Codes.INVALID_PARAM, ) if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) event_reports, total = await self.store.get_event_reports_paginate( @@ -90,7 +92,7 @@ class EventReportsRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(event_reports) - return 200, ret + return HTTPStatus.OK, ret class EventReportDetailRestServlet(RestServlet): @@ -112,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -127,13 +128,17 @@ class EventReportDetailRestServlet(RestServlet): try: resolved_report_id = int(report_id) except ValueError: - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) if resolved_report_id < 0: - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) ret = await self.store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") - return 200, ret + return HTTPStatus.OK, ret diff --git 
a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py new file mode 100644 index 0000000000..50d88c9109 --- /dev/null +++ b/synapse/rest/admin/federation.py @@ -0,0 +1,135 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.storage.databases.main.transactions import DestinationSortOrder +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ListDestinationsRestServlet(RestServlet): + """Get request to list all destinations. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v1/federation/destinations?from=0&limit=10 + + returns: + 200 OK with list of destinations if success otherwise an error. + + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `destination` can be used to filter by destination. + The parameter `order_by` can be used to order the result. + """ + + PATTERNS = admin_patterns("/federation/destinations$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + destination = parse_string(request, "destination") + + order_by = parse_string( + request, + "order_by", + default=DestinationSortOrder.DESTINATION.value, + allowed_values=[dest.value for dest in DestinationSortOrder], + ) + + direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + + destinations, total = await self._store.get_destinations_paginate( + start, limit, destination, order_by, direction + ) + response = {"destinations": destinations, "total": total} + if (start + limit) < total: + response["next_token"] = str(start + len(destinations)) + + return HTTPStatus.OK, response + + +class DestinationsRestServlet(RestServlet): + """Get details of a destination. + This needs user to have administrator access in Synapse. 
+ + GET /_synapse/admin/v1/federation/destinations/<destination> + + returns: + 200 OK with details of a destination if success otherwise an error. + """ + + PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, destination: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + destination_retry_timings = await self._store.get_destination_retry_timings( + destination + ) + + if not destination_retry_timings: + raise NotFoundError("Unknown destination") + + last_successful_stream_ordering = ( + await self._store.get_destination_last_successful_stream_ordering( + destination + ) + ) + + response = { + "destination": destination, + "failure_ts": destination_retry_timings.failure_ts, + "retry_last_ts": destination_retry_timings.retry_last_ts, + "retry_interval": destination_retry_timings.retry_interval, + "last_successful_stream_ordering": last_successful_stream_ordering, + } + + return HTTPStatus.OK, response diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index 68a3ba3cb7..cd697e180e 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError @@ -29,7 +30,7 @@ logger = logging.getLogger(__name__) class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups""" - PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)") + PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.group_server = hs.get_groups_server_handler() @@ -43,7 +44,7 @@ class DeleteGroupAdminRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) if not self.is_mine_id(group_id): - raise SynapseError(400, "Can only delete local groups") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups") await self.group_server.delete_group(group_id, requester.user.to_string()) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 30a687d234..7236e4027f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -14,9 +14,10 @@ # limitations under the License. 
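A sketch of how the two new federation endpoints above compose, under the same illustrative assumptions (a `requests` client, hypothetical base URL and token); the per-row "destination" key is inferred from the detail response shape, not spelled out in this commit:

    import requests  # assumed third-party HTTP client

    BASE = "http://localhost:8008"  # hypothetical
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # hypothetical

    # Page through destinations; "from"/"limit" paginate, while "order_by" and
    # "dir" accept any DestinationSortOrder value and "f"/"b" respectively.
    page = requests.get(
        f"{BASE}/_synapse/admin/v1/federation/destinations",
        headers=HEADERS,
        params={"from": 0, "limit": 10},
    ).json()

    for row in page["destinations"]:
        # "destination" is an assumed key name; fetch its retry timings
        # from the per-destination detail endpoint.
        detail = requests.get(
            f"{BASE}/_synapse/admin/v1/federation/destinations/" + row["destination"],
            headers=HEADERS,
        ).json()
        print(detail["destination"], detail["retry_last_ts"], detail["retry_interval"])

    # "next_token" appears in the response only while (start + limit) < total.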
import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -40,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = [ - *admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"), + *admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"), # This path kept around for legacy reasons - *admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"), + *admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"), ] def __init__(self, hs: "HomeServer"): @@ -62,7 +63,7 @@ class QuarantineMediaInRoom(RestServlet): room_id, requester.user.to_string() ) - return 200, {"num_quarantined": num_quarantined} + return HTTPStatus.OK, {"num_quarantined": num_quarantined} class QuarantineMediaByUser(RestServlet): @@ -70,7 +71,7 @@ class QuarantineMediaByUser(RestServlet): this server. """ - PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$") + PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -89,7 +90,7 @@ class QuarantineMediaByUser(RestServlet): user_id, requester.user.to_string() ) - return 200, {"num_quarantined": num_quarantined} + return HTTPStatus.OK, {"num_quarantined": num_quarantined} class QuarantineMediaByID(RestServlet): @@ -98,7 +99,7 @@ class QuarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" + "/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -118,7 +119,7 @@ class QuarantineMediaByID(RestServlet): server_name, media_id, requester.user.to_string() ) - return 200, {} + return HTTPStatus.OK, {} class UnquarantineMediaByID(RestServlet): @@ -127,7 +128,7 @@ class UnquarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" + "/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -137,8 +138,7 @@ class UnquarantineMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info( "Remove from quarantine local media by ID: %s/%s", server_name, media_id @@ -147,13 +147,13 @@ class UnquarantineMediaByID(RestServlet): # Remove from quarantine this media id await self.store.quarantine_media_by_id(server_name, media_id, None) - return 200, {} + return HTTPStatus.OK, {} class ProtectMediaByID(RestServlet): """Protect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -162,21 +162,20 @@ class ProtectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await 
assert_requester_is_admin(self.auth, request) logging.info("Protecting local media by ID: %s", media_id) # Protect this media id await self.store.mark_local_media_as_safe(media_id, safe=True) - return 200, {} + return HTTPStatus.OK, {} class UnprotectMediaByID(RestServlet): """Unprotect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -185,21 +184,20 @@ class UnprotectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Unprotecting local media by ID: %s", media_id) # Unprotect this media id await self.store.mark_local_media_as_safe(media_id, safe=False) - return 200, {} + return HTTPStatus.OK, {} class ListMediaInRoom(RestServlet): """Lists all of the media in a given room.""" - PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$") + PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -208,14 +206,11 @@ class ListMediaInRoom(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - is_admin = await self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(403, "You are not a server admin") + await assert_requester_is_admin(self.auth, request) local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) - return 200, {"local": local_mxcs, "remote": remote_mxcs} + return HTTPStatus.OK, {"local": local_mxcs, "remote": remote_mxcs} class PurgeMediaCacheRestServlet(RestServlet): @@ -233,13 +228,13 @@ class PurgeMediaCacheRestServlet(RestServlet): if before_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts you provided is from the year 1970. " + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, @@ -247,13 +242,13 @@ class PurgeMediaCacheRestServlet(RestServlet): ret = await self.media_repository.delete_old_remote_media(before_ts) - return 200, ret + return HTTPStatus.OK, ret class DeleteMediaByID(RestServlet): """Delete local media by a given ID. 
Removes it from this server.""" - PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -267,7 +262,7 @@ class DeleteMediaByID(RestServlet): await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: - raise SynapseError(400, "Can only delete local media") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") if await self.store.get_local_media(media_id) is None: raise NotFoundError("Unknown media") @@ -277,7 +272,7 @@ class DeleteMediaByID(RestServlet): deleted_media, total = await self.media_repository.delete_local_media_ids( [media_id] ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} class DeleteMediaByDateSize(RestServlet): @@ -285,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size. """ - PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$") + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -304,26 +299,26 @@ class DeleteMediaByDateSize(RestServlet): if before_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts must be a positive integer.", errcode=Codes.INVALID_PARAM, ) elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter before_ts you provided is from the year 1970. " + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, ) if size_gt < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter size_gt must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if self.server_name != server_name: - raise SynapseError(400, "Can only delete local media") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") logging.info( "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s" @@ -333,7 +328,7 @@ class DeleteMediaByDateSize(RestServlet): deleted_media, total = await self.media_repository.delete_old_local_media( before_ts, size_gt, keep_profiles ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} class UserMediaRestServlet(RestServlet): @@ -352,7 +347,7 @@ class UserMediaRestServlet(RestServlet): media that exist given for this user """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -369,7 +364,7 @@ class UserMediaRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -380,14 +375,14 @@ class UserMediaRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must 
be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -402,16 +397,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") @@ -425,7 +411,7 @@ class UserMediaRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(media) - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, user_id: str @@ -436,7 +422,7 @@ class UserMediaRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -447,14 +433,14 @@ class UserMediaRestServlet(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -469,16 +455,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") @@ -492,7 +469,7 @@ class UserMediaRestServlet(RestServlet): ([row["media_id"] for row in media]) ) - return 200, {"deleted_media": deleted_media, "total": total} + return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index aba48f6e7b..04948b6408 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -14,6 +14,7 @@ import logging import string +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -69,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -77,7 +77,7 @@ class ListRegistrationTokensRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) valid = parse_boolean(request, "valid") token_list = await self.store.get_registration_tokens(valid) - return 200, {"registration_tokens": token_list} + return 
HTTPStatus.OK, {"registration_tokens": token_list} class NewRegistrationTokenRestServlet(RestServlet): @@ -108,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/new$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -123,16 +122,20 @@ class NewRegistrationTokenRestServlet(RestServlet): if "token" in body: token = body["token"] if not isinstance(token, str): - raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "token must be a string", + Codes.INVALID_PARAM, + ) if not (0 < len(token) <= 64): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "token must not be empty and must not be longer than 64 characters", Codes.INVALID_PARAM, ) if not set(token).issubset(self.allowed_chars_set): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "token must consist only of characters matched by the regex [A-Za-z0-9-_]", Codes.INVALID_PARAM, ) @@ -142,11 +145,13 @@ class NewRegistrationTokenRestServlet(RestServlet): length = body.get("length", 16) if not isinstance(length, int): raise SynapseError( - 400, "length must be an integer", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "length must be an integer", + Codes.INVALID_PARAM, ) if not (0 < length <= 64): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "length must be greater than zero and not greater than 64", Codes.INVALID_PARAM, ) @@ -162,7 +167,7 @@ class NewRegistrationTokenRestServlet(RestServlet): or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -170,11 +175,15 @@ class NewRegistrationTokenRestServlet(RestServlet): expiry_time = body.get("expiry_time", None) if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must be an integer or null", + Codes.INVALID_PARAM, ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( - 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must not be in the past", + Codes.INVALID_PARAM, ) created = await self.store.create_registration_token( @@ -182,7 +191,9 @@ class NewRegistrationTokenRestServlet(RestServlet): ) if not created: raise SynapseError( - 400, f"Token already exists: {token}", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + f"Token already exists: {token}", + Codes.INVALID_PARAM, ) resp = { @@ -192,7 +203,7 @@ class NewRegistrationTokenRestServlet(RestServlet): "completed": 0, "expiry_time": expiry_time, } - return 200, resp + return HTTPStatus.OK, resp class RegistrationTokenRestServlet(RestServlet): @@ -247,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.clock = hs.get_clock() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -261,7 +271,7 @@ class RegistrationTokenRestServlet(RestServlet): if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return 200, token_info + return HTTPStatus.OK, token_info async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: """Update a registration token.""" @@ -277,7 
+287,7 @@ class RegistrationTokenRestServlet(RestServlet): or (isinstance(uses_allowed, int) and uses_allowed >= 0) ): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "uses_allowed must be a non-negative integer or null", Codes.INVALID_PARAM, ) @@ -287,11 +297,15 @@ class RegistrationTokenRestServlet(RestServlet): expiry_time = body["expiry_time"] if not isinstance(expiry_time, (int, type(None))): raise SynapseError( - 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must be an integer or null", + Codes.INVALID_PARAM, ) if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): raise SynapseError( - 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "expiry_time must not be in the past", + Codes.INVALID_PARAM, ) new_attributes["expiry_time"] = expiry_time @@ -307,7 +321,7 @@ class RegistrationTokenRestServlet(RestServlet): if token_info is None: raise NotFoundError(f"No such registration token: {token}") - return 200, token_info + return HTTPStatus.OK, token_info async def on_DELETE( self, request: SynapseRequest, token: str @@ -316,6 +330,6 @@ class RegistrationTokenRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if await self.store.delete_registration_token(token): - return 200, {} + return HTTPStatus.OK, {} raise NotFoundError(f"No such registration token: {token}") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index a2f4edebb8..17c6df1cc8 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -34,7 +34,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, RoomID, UserID, create_requester from synapse.util import json_decoder if TYPE_CHECKING: @@ -46,6 +46,139 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class RoomRestV2Servlet(RestServlet): + """Delete a room from server asynchronously with a background task. + + It is a combination and improvement of shutdown and purge room. + + Shuts down a room by removing all local users from the room. + Blocking all future invites and joins to the room is optional. + + If desired any local aliases will be repointed to a new room + created by `new_room_user_id` and kicked users will be auto- + joined to the new room. + + If 'purge' is true, it will remove all traces of a room from the database. 
+ """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self._pagination_handler = hs.get_pagination_handler() + + async def on_DELETE( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + content = parse_json_object_from_request(request) + + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) + + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + delete_id = self._pagination_handler.start_shutdown_and_purge_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, + purge=purge, + force_purge=force_purge, + ) + + return HTTPStatus.OK, {"delete_id": delete_id} + + +class DeleteRoomStatusByRoomIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id) + if delete_ids is None: + raise NotFoundError("No delete task for room_id '%s' found" % room_id) + + response = [] + for delete_id in delete_ids: + delete = self._pagination_handler.get_delete_status(delete_id) + if delete: + response += [ + { + "delete_id": delete_id, + **delete.asdict(), + } + ] + return HTTPStatus.OK, {"results": cast(JsonDict, response)} + + +class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): + """Get the status of the delete room background task.""" + + PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._pagination_handler = hs.get_pagination_handler() + + async def on_GET( + self, request: SynapseRequest, delete_id: str + ) -> Tuple[int, JsonDict]: + + await assert_requester_is_admin(self._auth, request) + + delete_status = self._pagination_handler.get_delete_status(delete_id) + if delete_status is None: + raise NotFoundError("delete id '%s' not found" % delete_id) + + return HTTPStatus.OK, cast(JsonDict, delete_status.asdict()) + + class ListRoomRestServlet(RestServlet): """ List all rooms that are known to the homeserver. 
Results are returned @@ -60,40 +193,22 @@ class ListRoomRestServlet(RestServlet): self.admin_handler = hs.get_admin_handler() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) # Extract query parameters start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) - if order_by not in ( - RoomSortOrder.ALPHABETICAL.value, - RoomSortOrder.SIZE.value, - RoomSortOrder.NAME.value, - RoomSortOrder.CANONICAL_ALIAS.value, - RoomSortOrder.JOINED_MEMBERS.value, - RoomSortOrder.JOINED_LOCAL_MEMBERS.value, - RoomSortOrder.VERSION.value, - RoomSortOrder.CREATOR.value, - RoomSortOrder.ENCRYPTION.value, - RoomSortOrder.FEDERATABLE.value, - RoomSortOrder.PUBLIC.value, - RoomSortOrder.JOIN_RULES.value, - RoomSortOrder.GUEST_ACCESS.value, - RoomSortOrder.HISTORY_VISIBILITY.value, - RoomSortOrder.STATE_EVENTS.value, - ): - raise SynapseError( - 400, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) + order_by = parse_string( + request, + "order_by", + default=RoomSortOrder.NAME.value, + allowed_values=[sort_order.value for sort_order in RoomSortOrder], + ) search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "search_term cannot be an empty string", errcode=Codes.INVALID_PARAM, ) @@ -101,7 +216,9 @@ class ListRoomRestServlet(RestServlet): direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) reverse_order = True if direction == "b" else False @@ -133,7 +250,7 @@ class ListRoomRestServlet(RestServlet): else: response["prev_batch"] = 0 - return 200, response + return HTTPStatus.OK, response class RoomRestServlet(RestServlet): @@ -157,10 +274,9 @@ class RoomRestServlet(RestServlet): TODO: Add on_POST to allow room creation without joining the room """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() @@ -178,7 +294,7 @@ class RoomRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret["joined_local_devices"] = await self.store.count_devices_by_users(members) - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -254,7 +370,7 @@ class RoomRestServlet(RestServlet): # See https://github.com/python/mypy/issues/4976#issuecomment-579883622 # for some discussion on why this is necessary. Either way, # `ret` is an opaque dictionary blob as far as the rest of the app cares. - return 200, cast(JsonDict, ret) + return HTTPStatus.OK, cast(JsonDict, ret) class RoomMembersRestServlet(RestServlet): @@ -262,10 +378,9 @@ class RoomMembersRestServlet(RestServlet): Get members list of a room. 
""" - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -281,7 +396,7 @@ class RoomMembersRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret = {"members": members, "total": len(members)} - return 200, ret + return HTTPStatus.OK, ret class RoomStateRestServlet(RestServlet): @@ -289,10 +404,9 @@ class RoomStateRestServlet(RestServlet): Get full state within a room. """ - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state") + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -301,8 +415,7 @@ class RoomStateRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) if not ret: @@ -311,28 +424,22 @@ class RoomStateRestServlet(RestServlet): event_ids = await self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events( - events.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. - bundle_aggregations=False, - ) + room_state = await self._event_serializer.serialize_events(events.values(), now) ret = {"state": room_state} - return 200, ret + return HTTPStatus.OK, ret class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") + PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, room_identifier: str @@ -348,8 +455,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): assert_params_in_dict(content, ["user_id"]) target_user = UserID.from_string(content["user_id"]) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "This endpoint can only be used with local users") + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "This endpoint can only be used with local users", + ) if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -395,7 +505,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): ratelimit=False, ) - return 200, {"room_id": room_id} + return HTTPStatus.OK, {"room_id": room_id} class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): @@ -410,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): } """ - PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin") + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_creation_handler = hs.get_event_creation_handler() @@ -436,7 +545,7 @@ class 
MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): # Figure out which local users currently have power in the room, if any. room_state = await self.state_handler.get_current_state(room_id) if not room_state: - raise SynapseError(400, "Server not in room") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room") create_event = room_state[(EventTypes.Create, "")] power_levels = room_state.get((EventTypes.PowerLevels, "")) @@ -450,7 +559,9 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): admin_users.sort(key=lambda user: user_power[user]) if not admin_users: - raise SynapseError(400, "No local admin user in room") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "No local admin user in room" + ) admin_user_id = None @@ -467,7 +578,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): if not admin_user_id: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No local admin user in room", ) @@ -478,7 +589,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): admin_user_id = create_event.sender if not self.is_mine_id(admin_user_id): raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No local admin user in room", ) @@ -507,7 +618,8 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): except AuthError: # The admin user we found turned out not to have enough power. raise SynapseError( - 400, "No local admin user in room with power to update power levels." + HTTPStatus.BAD_REQUEST, + "No local admin user in room with power to update power levels.", ) # Now we check if the user we're granting admin rights to is already in @@ -521,7 +633,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): ) if is_joined: - return 200, {} + return HTTPStatus.OK, {} join_rules = room_state.get((EventTypes.JoinRules, "")) is_public = False @@ -529,7 +641,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC if is_public: - return 200, {} + return HTTPStatus.OK, {} await self.room_member_handler.update_membership( fake_requester, @@ -538,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): action=Membership.INVITE, ) - return 200, {} + return HTTPStatus.OK, {} class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): @@ -553,35 +665,32 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities """ - PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities") + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() async def on_DELETE( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) deleted_count = await self.store.delete_forward_extremities_for_room(room_id) - return 200, {"deleted": deleted_count} + return HTTPStatus.OK, {"deleted": deleted_count} async def on_GET( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await 
assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) - return 200, {"count": len(extremities), "results": extremities} + return HTTPStatus.OK, {"count": len(extremities), "results": extremities} class RoomEventContextServlet(RestServlet): @@ -630,7 +739,9 @@ class RoomEventContextServlet(RestServlet): ) if not results: - raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + raise SynapseError( + HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND + ) time_now = self.clock.time_msec() results["events_before"] = await self._event_serializer.serialize_events( @@ -643,10 +754,70 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_aggregations=False, + results["state"], time_now ) - return 200, results + return HTTPStatus.OK, results + + +class BlockRoomRestServlet(RestServlet): + """ + Manage blocking of rooms. + On PUT: Add or remove a room from blocking list. + On GET: Get blocking status of room and user who has blocked this room. + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + blocked_by = await self._store.room_is_blocked_by(room_id) + # Test `not None` if `user_id` is an empty string + # if someone add manually an entry in database + if blocked_by is not None: + response = {"block": True, "user_id": blocked_by} + else: + response = {"block": False} + + return HTTPStatus.OK, response + + async def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester.user) + + content = parse_json_object_from_request(request) + + if not RoomID.is_valid(room_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) + ) + + assert_params_in_dict(content, ["block"]) + block = content.get("block") + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean.", + Codes.BAD_JSON, + ) + + if block: + await self._store.block_room(room_id, requester.user.to_string()) + else: + await self._store.unblock_room(room_id) + + return HTTPStatus.OK, {"block": block} diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index 19f84f33f2..15da9cd881 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
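The BlockRoomRestServlet above gives room blocking a symmetric GET/PUT surface; a short usage sketch under the same illustrative assumptions (a `requests` client, hypothetical server, token and room ID):

    import requests  # assumed third-party HTTP client

    BASE = "http://localhost:8008"  # hypothetical
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # hypothetical
    room_id = "!somewhere:localhost"  # hypothetical room ID

    # PUT requires a boolean "block" param; anything else is rejected as BAD_JSON.
    requests.put(
        f"{BASE}/_synapse/admin/v1/rooms/{room_id}/block",
        headers=HEADERS,
        json={"block": True},
    ).raise_for_status()

    # GET reports whether the room is blocked and, if so, by whom.
    status = requests.get(
        f"{BASE}/_synapse/admin/v1/rooms/{room_id}/block", headers=HEADERS
    ).json()
    print(status)  # {"block": true, "user_id": <admin who blocked it>} when blocked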
+from http import HTTPStatus from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes @@ -51,11 +52,11 @@ class SendServerNoticeServlet(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.server_notices_manager = hs.get_server_notices_manager() self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) + self.is_mine = hs.is_mine def register(self, json_resource: HttpServer) -> None: PATTERN = "/send_server_notice" @@ -82,11 +83,15 @@ class SendServerNoticeServlet(RestServlet): # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the # admin api). if not self.server_notices_manager.is_enabled(): - raise SynapseError(400, "Server notices are not enabled on this server") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Server notices are not enabled on this server" + ) target_user = UserID.from_string(body["user_id"]) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Server notices can only be sent to local users") + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" + ) if not await self.admin_handler.get_user(target_user): raise NotFoundError("User not found") @@ -99,7 +104,7 @@ class SendServerNoticeServlet(RestServlet): txn_id=txn_id, ) - return 200, {"event_id": event.event_id} + return HTTPStatus.OK, {"event_id": event.event_id} def on_PUT( self, request: SynapseRequest, txn_id: str diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 948de94ccd..7a6546372e 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -13,6 +13,7 @@ # limitations under the License. 
import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError @@ -36,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet): PATTERNS = admin_patterns("/statistics/users/media$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -44,24 +44,21 @@ class UserMediaStatisticsRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) order_by = parse_string( - request, "order_by", default=UserSortOrder.USER_ID.value + request, + "order_by", + default=UserSortOrder.USER_ID.value, + allowed_values=( + UserSortOrder.MEDIA_LENGTH.value, + UserSortOrder.MEDIA_COUNT.value, + UserSortOrder.USER_ID.value, + UserSortOrder.DISPLAYNAME.value, + ), ) - if order_by not in ( - UserSortOrder.MEDIA_LENGTH.value, - UserSortOrder.MEDIA_COUNT.value, - UserSortOrder.USER_ID.value, - UserSortOrder.DISPLAYNAME.value, - ): - raise SynapseError( - 400, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) start = parse_integer(request, "from", default=0) if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -69,7 +66,7 @@ class UserMediaStatisticsRestServlet(RestServlet): limit = parse_integer(request, "limit", default=100) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -77,7 +74,7 @@ class UserMediaStatisticsRestServlet(RestServlet): from_ts = parse_integer(request, "from_ts", default=0) if from_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -86,13 +83,13 @@ class UserMediaStatisticsRestServlet(RestServlet): if until_ts is not None: if until_ts < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if until_ts <= from_ts: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter until_ts must be greater than from_ts.", errcode=Codes.INVALID_PARAM, ) @@ -100,7 +97,7 @@ class UserMediaStatisticsRestServlet(RestServlet): search_term = parse_string(request, "search_term") if search_term == "": raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter search_term cannot be an empty string.", errcode=Codes.INVALID_PARAM, ) @@ -108,7 +105,9 @@ class UserMediaStatisticsRestServlet(RestServlet): direction = parse_string(request, "dir", default="f") if direction not in ("f", "b"): raise SynapseError( - 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + HTTPStatus.BAD_REQUEST, + "Unknown direction: %s" % (direction,), + errcode=Codes.INVALID_PARAM, ) users_media, total = await self.store.get_users_media_usage_paginate( @@ -118,4 +117,4 @@ class UserMediaStatisticsRestServlet(RestServlet): if (start + limit) < total: ret["next_token"] = start + len(users_media) - return 200, ret + return HTTPStatus.OK, ret diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py index 2bf1472967..5353dc3682 100644 --- a/synapse/rest/admin/username_available.py +++ b/synapse/rest/admin/username_available.py @@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet): } """ - PATTERNS = 
admin_patterns("/username_available") + PATTERNS = admin_patterns("/username_available$") def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index d14fafbbc9..db678da4cf 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -79,14 +78,14 @@ class UsersRestServletV2(RestServlet): if start < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter from must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) if limit < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "Query parameter limit must be a string representing a positive integer.", errcode=Codes.INVALID_PARAM, ) @@ -122,11 +121,11 @@ class UsersRestServletV2(RestServlet): if (start + limit) < total: ret["next_token"] = str(start + len(users)) - return 200, ret + return HTTPStatus.OK, ret class UserRestServletV2(RestServlet): - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2") """Get request to list user details. This needs user to have administrator access in Synapse. @@ -172,14 +171,14 @@ class UserRestServletV2(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") ret = await self.admin_handler.get_user(target_user) if not ret: raise NotFoundError("User not found") - return 200, ret + return HTTPStatus.OK, ret async def on_PUT( self, request: SynapseRequest, user_id: str @@ -191,7 +190,10 @@ class UserRestServletV2(RestServlet): body = parse_json_object_from_request(request) if not self.hs.is_mine(target_user): - raise SynapseError(400, "This endpoint can only be used with local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "This endpoint can only be used with local users", + ) user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() @@ -210,7 +212,7 @@ class UserRestServletV2(RestServlet): user_type = body.get("user_type", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") set_admin_to = body.get("admin", False) if not isinstance(set_admin_to, bool): @@ -223,11 +225,13 @@ class UserRestServletV2(RestServlet): password = body.get("password", None) if password is not None: if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") deactivate = body.get("deactivated", False) if not isinstance(deactivate, bool): - raise SynapseError(400, "'deactivated' parameter is not of type boolean") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" + ) # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: @@ -282,7 +286,9 @@ class UserRestServletV2(RestServlet): user_id, ) except ExternalIDReuseException: - raise SynapseError(409, "External id is already in use.") + raise SynapseError( + HTTPStatus.CONFLICT, "External id is already in use." 
+ ) if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -293,7 +299,9 @@ class UserRestServletV2(RestServlet): if set_admin_to != user["admin"]: auth_user = requester.user if target_user == auth_user and not set_admin_to: - raise SynapseError(400, "You may not demote yourself.") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "You may not demote yourself." + ) await self.store.set_server_admin(target_user, set_admin_to) @@ -319,7 +327,8 @@ class UserRestServletV2(RestServlet): and self.auth_handler.can_change_password() ): raise SynapseError( - 400, "Must provide a password to re-activate an account." + HTTPStatus.BAD_REQUEST, + "Must provide a password to re-activate an account.", ) await self.deactivate_account_handler.activate_account( @@ -332,7 +341,7 @@ class UserRestServletV2(RestServlet): user = await self.admin_handler.get_user(target_user) assert user is not None - return 200, user + return HTTPStatus.OK, user else: # create user displayname = body.get("displayname", None) @@ -381,7 +390,9 @@ class UserRestServletV2(RestServlet): user_id, ) except ExternalIDReuseException: - raise SynapseError(409, "External id is already in use.") + raise SynapseError( + HTTPStatus.CONFLICT, "External id is already in use." + ) if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -402,7 +413,7 @@ class UserRegisterServlet(RestServlet): nonce to the time it was generated, in int seconds. """ - PATTERNS = admin_patterns("/register") + PATTERNS = admin_patterns("/register$") NONCE_TIMEOUT = 60 def __init__(self, hs: "HomeServer"): @@ -429,51 +440,61 @@ class UserRegisterServlet(RestServlet): nonce = secrets.token_hex(64) self.nonces[nonce] = int(self.reactor.seconds()) - return 200, {"nonce": nonce} + return HTTPStatus.OK, {"nonce": nonce} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() if not self.hs.config.registration.registration_shared_secret: - raise SynapseError(400, "Shared secret registration is not enabled") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled" + ) body = parse_json_object_from_request(request) if "nonce" not in body: - raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "nonce must be specified", + errcode=Codes.BAD_JSON, + ) nonce = body["nonce"] if nonce not in self.nonces: - raise SynapseError(400, "unrecognised nonce") + raise SynapseError(HTTPStatus.BAD_REQUEST, "unrecognised nonce") # Delete the nonce, so it can't be reused, even if it's invalid del self.nonces[nonce] if "username" not in body: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "username must be specified", + errcode=Codes.BAD_JSON, ) else: if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(400, "Invalid username") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") username = body["username"].encode("utf-8") if b"\x00" in username: - raise SynapseError(400, "Invalid username") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username") if "password" not in body: raise SynapseError( - 400, "password must be specified", errcode=Codes.BAD_JSON + HTTPStatus.BAD_REQUEST, + "password must be specified", + errcode=Codes.BAD_JSON, ) else: password = body["password"] if not isinstance(password, str) or len(password) > 
512: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") password_bytes = password.encode("utf-8") if b"\x00" in password_bytes: - raise SynapseError(400, "Invalid password") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") password_hash = await self.auth_handler.hash(password) @@ -482,10 +503,12 @@ class UserRegisterServlet(RestServlet): displayname = body.get("displayname", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") if "mac" not in body: - raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "mac must be specified", errcode=Codes.BAD_JSON + ) got_mac = body["mac"] @@ -507,7 +530,7 @@ class UserRegisterServlet(RestServlet): want_mac = want_mac_builder.hexdigest() if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): - raise SynapseError(403, "HMAC incorrect") + raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication from synapse.rest.client.register import RegisterRestServlet @@ -524,7 +547,7 @@ class UserRegisterServlet(RestServlet): ) result = await register._create_registration_details(user_id, body) - return 200, result + return HTTPStatus.OK, result class WhoisRestServlet(RestServlet): @@ -537,9 +560,9 @@ class WhoisRestServlet(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -551,16 +574,16 @@ class WhoisRestServlet(RestServlet): if target_user != auth_user: await assert_user_is_admin(self.auth, auth_user) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only whois a local user") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) - return 200, ret + return HTTPStatus.OK, ret class DeactivateAccountRestServlet(RestServlet): - PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() @@ -575,7 +598,9 @@ class DeactivateAccountRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) if not self.is_mine(UserID.from_string(target_user_id)): - raise SynapseError(400, "Can only deactivate local users") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Can only deactivate local users" + ) if not await self.store.get_user_by_id(target_user_id): raise NotFoundError("User not found") @@ -597,14 +622,13 @@ class DeactivateAccountRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - return 200, {"id_server_unbind_result": id_server_unbind_result} + return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result} class AccountValidityRenewServlet(RestServlet): PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() @@ -620,7 +644,7 @@ class AccountValidityRenewServlet(RestServlet): if "user_id" not in body: raise SynapseError( 
- 400, + HTTPStatus.BAD_REQUEST, "Missing property 'user_id' in the request body", ) @@ -631,7 +655,7 @@ class AccountValidityRenewServlet(RestServlet): ) res = {"expiration_ts": expiration_ts} - return 200, res + return HTTPStatus.OK, res class ResetPasswordRestServlet(RestServlet): @@ -648,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet): 200 OK with empty object if success otherwise an error. """ - PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -678,7 +701,7 @@ class ResetPasswordRestServlet(RestServlet): await self._set_password_handler.set_password( target_user_id, new_password_hash, logout_devices, requester ) - return 200, {} + return HTTPStatus.OK, {} class SearchUsersRestServlet(RestServlet): @@ -692,12 +715,12 @@ class SearchUsersRestServlet(RestServlet): 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, target_user_id: str @@ -712,16 +735,16 @@ class SearchUsersRestServlet(RestServlet): # To allow all users to get the users list # if not is_admin and target_user != auth_user: - # raise AuthError(403, "You are not a server admin") + # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only users a local user") + if not self.is_mine(target_user): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") term = parse_string(request, "term", required=True) logger.info("term: %s ", term) ret = await self.store.search_users(term) - return 200, ret + return HTTPStatus.OK, ret class UserAdminServlet(RestServlet): @@ -753,9 +776,9 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -764,12 +787,15 @@ class UserAdminServlet(RestServlet): target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Only local users can be admins of this homeserver") + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Only local users can be admins of this homeserver", + ) is_admin = await self.store.is_server_admin(target_user) - return 200, {"admin": is_admin} + return HTTPStatus.OK, {"admin": is_admin} async def on_PUT( self, request: SynapseRequest, user_id: str @@ -784,17 +810,20 @@ class UserAdminServlet(RestServlet): assert_params_in_dict(body, ["admin"]) - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Only local users can be admins of this homeserver") + if not self.is_mine(target_user): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Only local users can be admins of this homeserver", + ) set_admin_to = bool(body["admin"]) if target_user == auth_user and not set_admin_to: - raise 
SynapseError(400, "You may not demote yourself.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "You may not demote yourself.") await self.store.set_server_admin(target_user, set_admin_to) - return 200, {} + return HTTPStatus.OK, {} class UserMembershipRestServlet(RestServlet): @@ -802,7 +831,7 @@ class UserMembershipRestServlet(RestServlet): Get room list of an user. """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -816,7 +845,7 @@ class UserMembershipRestServlet(RestServlet): room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} - return 200, ret + return HTTPStatus.OK, ret class PushersRestServlet(RestServlet): @@ -845,7 +874,7 @@ class PushersRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only look up local users") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -854,7 +883,10 @@ class PushersRestServlet(RestServlet): filtered_pushers = [p.as_dict() for p in pushers] - return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} + return HTTPStatus.OK, { + "pushers": filtered_pushers, + "total": len(filtered_pushers), + } class UserTokenRestServlet(RestServlet): @@ -874,10 +906,10 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str @@ -886,30 +918,36 @@ class UserTokenRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be logged in as") + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" + ) body = parse_json_object_from_request(request, allow_empty_body=True) valid_until_ms = body.get("valid_until_ms") if valid_until_ms and not isinstance(valid_until_ms, int): - raise SynapseError(400, "'valid_until_ms' parameter must be an int") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int" + ) if auth_user.to_string() == user_id: - raise SynapseError(400, "Cannot use admin API to login as self") + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Cannot use admin API to login as self" + ) - token = await self.auth_handler.get_access_token_for_user_id( + token = await self.auth_handler.create_access_token_for_user_id( user_id=auth_user.to_string(), device_id=None, valid_until_ms=valid_until_ms, puppets_user_id=user_id, ) - return 200, {"access_token": token} + return HTTPStatus.OK, {"access_token": token} class ShadowBanRestServlet(RestServlet): - """An admin API for shadow-banning a user. + """An admin API for controlling whether a user is shadow-banned. A shadow-banned users receives successful responses to their client-server API requests, but the events are not propagated into rooms. 
@@ -917,33 +955,57 @@ class ShadowBanRestServlet(RestServlet): Shadow-banning a user should be used as a tool of last resort and may lead to confusing or broken behaviour for the client. - Example: + Example of shadow-banning a user: POST /_synapse/admin/v1/users/@test:example.com/shadow_ban {} 200 OK {} + + Example of removing a user from being shadow-banned: + + DELETE /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be shadow-banned") + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" + ) await self.store.set_shadow_banned(UserID.from_string(user_id), True) - return 200, {} + return HTTPStatus.OK, {} + + async def on_DELETE( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" + ) + + await self.store.set_shadow_banned(UserID.from_string(user_id), False) + + return HTTPStatus.OK, {} class RateLimitRestServlet(RestServlet): @@ -962,20 +1024,20 @@ class RateLimitRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Can only look up local users") + if not self.is_mine_id(user_id): + raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -996,15 +1058,17 @@ class RateLimitRestServlet(RestServlet): else: ret = {} - return 200, ret + return HTTPStatus.OK, ret async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be ratelimited") + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" + ) if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -1016,14 +1080,14 @@ class RateLimitRestServlet(RestServlet): if not isinstance(messages_per_second, int) or messages_per_second < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) if not isinstance(burst_count, int) or burst_count < 0: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), errcode=Codes.INVALID_PARAM, ) @@ -1039,19 +1103,21 @@ class 
RateLimitRestServlet(RestServlet): "burst_count": ratelimit.burst_count, } - return 200, ret + return HTTPStatus.OK, ret async def on_DELETE( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Only local users can be ratelimited") + if not self.is_mine_id(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" + ) if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") await self.store.delete_ratelimit_for_user(user_id) - return 200, {} + return HTTPStatus.OK, {} diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index a0971ce994..b4cb90cb76 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) def client_patterns( path_regex: str, - releases: Iterable[int] = (0,), + releases: Iterable[str] = ("r0", "v3"), unstable: bool = True, v1: bool = False, ) -> Iterable[Pattern]: @@ -52,7 +52,7 @@ def client_patterns( v1_prefix = CLIENT_API_PREFIX + "/api/v1" patterns.append(re.compile("^" + v1_prefix + path_regex)) for release in releases: - new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) + new_prefix = CLIENT_API_PREFIX + f"/{release}" patterns.append(re.compile("^" + new_prefix + path_regex)) return patterns diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 7281b2ee29..730c18f08f 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -262,7 +262,7 @@ class SigningKeyUploadServlet(RestServlet): } """ - PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) + PATTERNS = client_patterns("/keys/device_signing/upload$", releases=("v3",)) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index d49a647b03..f9994658c4 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -14,7 +14,17 @@ import logging import re -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Union, +) from typing_extensions import TypedDict @@ -28,7 +38,6 @@ from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, assert_params_in_dict, - parse_boolean, parse_bytes_from_args, parse_json_object_from_request, parse_string, @@ -61,8 +70,9 @@ class LoginRestServlet(RestServlet): TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" - APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" - REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" + APPSERVICE_TYPE = "m.login.application_service" + APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service" + REFRESH_TOKEN_PARAM = "refresh_token" def __init__(self, hs: "HomeServer"): super().__init__() @@ -71,6 +81,7 @@ class LoginRestServlet(RestServlet): # JWT configuration variables. 
self.jwt_enabled = hs.config.jwt.jwt_enabled self.jwt_secret = hs.config.jwt.jwt_secret + self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim self.jwt_algorithm = hs.config.jwt.jwt_algorithm self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences @@ -79,7 +90,9 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._refresh_tokens_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self.auth = hs.get_auth() @@ -143,23 +156,29 @@ class LoginRestServlet(RestServlet): flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE_UNSTABLE}) return 200, {"flows": flows} async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: login_submission = parse_json_object_from_request(request) - if self._msc2918_enabled: - # Check if this login should also issue a refresh token, as per - # MSC2918 - should_issue_refresh_token = parse_boolean( - request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False - ) - else: - should_issue_refresh_token = False + # Check to see if the client requested a refresh token. + client_requested_refresh_token = login_submission.get( + LoginRestServlet.REFRESH_TOKEN_PARAM, False + ) + if not isinstance(client_requested_refresh_token, bool): + raise SynapseError(400, "`refresh_token` should be true or false.") + + should_issue_refresh_token = ( + self._refresh_tokens_enabled and client_requested_refresh_token + ) try: - if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + if login_submission["type"] in ( + LoginRestServlet.APPSERVICE_TYPE, + LoginRestServlet.APPSERVICE_TYPE_UNSTABLE, + ): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): @@ -283,6 +302,7 @@ class LoginRestServlet(RestServlet): ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -298,10 +318,10 @@ class LoginRestServlet(RestServlet): create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: True if this login should issue a refresh token alongside the access token. + auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: result: Dictionary of account information after successful login. 
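For reference, the stable flow the login hunks above implement looks roughly like this from the client side. This is a sketch, assuming a homeserver with `refreshable_access_token_lifetime` configured; the URL, username and password are placeholders:

import requests

base_url = "https://example.com"  # placeholder homeserver

# Log in, opting into refresh tokens via the stable body parameter
# (formerly the unstable MSC2918 query parameter).
resp = requests.post(
    f"{base_url}/_matrix/client/v3/login",
    json={
        "type": "m.login.password",
        "identifier": {"type": "m.id.user", "user": "alice"},
        "password": "correct-horse",
        "refresh_token": True,
    },
)
resp.raise_for_status()
tokens = resp.json()  # access_token, refresh_token, expires_in_ms, ...

# Later, trade the refresh token for a fresh access token on the stable
# /_matrix/client/v1/refresh endpoint registered above.
refreshed = requests.post(
    f"{base_url}/_matrix/client/v1/refresh",
    json={"refresh_token": tokens["refresh_token"]},
).json()
# `expires_in_ms` is only present when the new access token expires.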
@@ -334,6 +354,7 @@ class LoginRestServlet(RestServlet): initial_display_name, auth_provider_id=auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=auth_provider_session_id, ) result = LoginResponse( @@ -379,6 +400,7 @@ class LoginRestServlet(RestServlet): self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=res.auth_provider_session_id, ) async def _do_jwt_login( @@ -408,7 +430,7 @@ class LoginRestServlet(RestServlet): errcode=Codes.FORBIDDEN, ) - user = payload.get("sub", None) + user = payload.get(self.jwt_subject_claim, None) if user is None: raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) @@ -440,14 +462,15 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: class RefreshTokenServlet(RestServlet): - PATTERNS = client_patterns( - "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True - ) + PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),) def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.registration.access_token_lifetime + self.refreshable_access_token_lifetime = ( + hs.config.registration.refreshable_access_token_lifetime + ) + self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -457,27 +480,40 @@ class RefreshTokenServlet(RestServlet): if not isinstance(token, str): raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) - valid_until_ms = self._clock.time_msec() + self.access_token_lifetime - access_token, refresh_token = await self._auth_handler.refresh_token( - token, valid_until_ms - ) - expires_in_ms = valid_until_ms - self._clock.time_msec() - return ( - 200, - { - "access_token": access_token, - "refresh_token": refresh_token, - "expires_in_ms": expires_in_ms, - }, + now = self._clock.time_msec() + access_valid_until_ms = None + if self.refreshable_access_token_lifetime is not None: + access_valid_until_ms = now + self.refreshable_access_token_lifetime + refresh_valid_until_ms = None + if self.refresh_token_lifetime is not None: + refresh_valid_until_ms = now + self.refresh_token_lifetime + + ( + access_token, + refresh_token, + actual_access_token_expiry, + ) = await self._auth_handler.refresh_token( + token, access_valid_until_ms, refresh_valid_until_ms ) + response: Dict[str, Union[str, int]] = { + "access_token": access_token, + "refresh_token": refresh_token, + } + + # expires_in_ms is only present if the token expires + if actual_access_token_expiry is not None: + response["expires_in_ms"] = actual_access_token_expiry - now + + return 200, response + class SsoRedirectServlet(RestServlet): PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ re.compile( "^" + CLIENT_API_PREFIX - + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$" + + "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$" ) ] @@ -556,7 +592,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.registration.access_token_lifetime is not None: + if hs.config.registration.refreshable_access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) 
SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index d1d8a984c6..b12a332776 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import ReceiptTypes from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -54,7 +55,7 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, "m.read" + user_id, ReceiptTypes.READ ) notif_event_ids = [pa["event_id"] for pa in push_actions] diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 43c04fac6f..f51be511d1 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet): await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) - read_event_id = body.get("m.read", None) + read_event_id = body.get(ReceiptTypes.READ, None) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): @@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet): if read_event_id: await self.receipts_handler.received_client_receipt( room_id, - "m.read", + ReceiptTypes.READ, user_id=requester.user.to_string(), event_id=read_event_id, hidden=hidden, diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 2b25b9aad6..b24ad2d1be 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,7 +16,7 @@ import logging import re from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http import get_request_user_agent from synapse.http.server import HttpServer @@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - if receipt_type != "m.read": + if receipt_type != ReceiptTypes.READ: raise SynapseError(400, "Receipt type must be 'm.read'") # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body. 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index bf3cb34146..8b56c76aed 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -41,7 +41,6 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, - parse_boolean, parse_json_object_from_request, parse_string, ) @@ -420,7 +419,9 @@ class RegisterRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.registration.enable_registration - self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None + self._refresh_tokens_enabled = ( + hs.config.registration.refreshable_access_token_lifetime is not None + ) self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -444,14 +445,15 @@ class RegisterRestServlet(RestServlet): f"Do not understand membership kind: {kind}", ) - if self._msc2918_enabled: - # Check if this registration should also issue a refresh token, as - # per MSC2918 - should_issue_refresh_token = parse_boolean( - request, name="org.matrix.msc2918.refresh_token", default=False - ) - else: - should_issue_refresh_token = False + # Check if the client wishes for this registration to issue a refresh + # token. + client_requested_refresh_tokens = body.get("refresh_token", False) + if not isinstance(client_requested_refresh_tokens, bool): + raise SynapseError(400, "`refresh_token` should be true or false.") + + should_issue_refresh_token = ( + self._refresh_tokens_enabled and client_requested_refresh_tokens + ) # Pull out the provided username and do basic sanity checks early since # the auth layer will store these in sessions. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 184cfbe196..fc4e6921c5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -224,18 +224,14 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - # We set bundle_aggregations to False when retrieving the original - # event because we want the content before relations were applied to - # it. + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. original_event = await self._event_serializer.serialize_event( event, now, bundle_aggregations=False ) - # Similarly, we don't allow relations to be applied to relations, so we - # return the original relations without any aggregations on top of them - # here. - serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=False - ) + # The relations returned for the requested event do include their + # bundled aggregations.
+ serialized_events = await self._event_serializer.serialize_events(events, now) return_value = pagination_chunk.to_dict() return_value["chunk"] = serialized_events diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 03a353d53c..f48e2e6ca2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -716,10 +716,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_aggregations=False, + results["state"], time_now ) return 200, results @@ -1070,6 +1067,62 @@ def register_txn_path( ) +class TimestampLookupRestServlet(RestServlet): + """ + API endpoint to fetch the `event_id` of the closest event to the given + timestamp (`ts` query parameter) in the given direction (`dir` query + parameter). + + Useful for cases like jump to date so you can start paginating messages from + a given date in the archive. + + `ts` is a timestamp in milliseconds where we will find the closest event in + the given direction. + + `dir` can be `f` or `b` to indicate forwards and backwards in time from the + given timestamp. + + GET /_matrix/client/unstable/org.matrix.msc3030/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction> + { + "event_id": ... + } + """ + + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc3030" + "/rooms/(?P<room_id>[^/]*)/timestamp_to_event$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._store = hs.get_datastore() + self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await self._auth.check_user_in_room(room_id, requester.user.to_string()) + + timestamp = parse_integer(request, "ts", required=True) + direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) + + ( + event_id, + origin_server_ts, + ) = await self.timestamp_lookup_handler.get_event_for_timestamp( + requester, room_id, timestamp, direction + ) + + return 200, { + "event_id": event_id, + "origin_server_ts": origin_server_ts, + } + + class RoomSpaceSummaryRestServlet(RestServlet): PATTERNS = ( re.compile( @@ -1140,7 +1193,7 @@ class RoomSpaceSummaryRestServlet(RestServlet): class RoomHierarchyRestServlet(RestServlet): PATTERNS = ( re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" + "^/_matrix/client/(v1|unstable/org.matrix.msc2946)" "/rooms/(?P<room_id>[^/]*)/hierarchy$" ), ) @@ -1168,7 +1221,7 @@ class RoomHierarchyRestServlet(RestServlet): ) return 200, await self._room_summary_handler.get_room_hierarchy( - requester.user.to_string(), + requester, room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), max_depth=max_depth, @@ -1239,6 +1292,8 @@ def register_servlets( RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) + if hs.config.experimental.msc3030_enabled: + TimestampLookupRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. 
if not is_worker: diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8c0fdb1940..88e4f5e063 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -520,9 +520,9 @@ class SyncRestServlet(RestServlet): return self._event_serializer.serialize_events( events, time_now=time_now, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_aggregations=False, + # Don't bother to bundle aggregations if the timeline is unlimited, + # as clients will have all the necessary information. + bundle_aggregations=room.timeline.limited, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 014fa893d6..9b40fd8a6c 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.util.stringutils import is_ascii +from synapse.util.stringutils import is_ascii, parse_and_validate_server_name logger = logging.getLogger(__name__) @@ -51,6 +51,19 @@ TEXT_CONTENT_TYPES = [ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: + """Parses the server name, media ID and optional file name from the request URI + + Also performs some rough validation on the server name. + + Args: + request: The `Request`. + + Returns: + A tuple containing the parsed server name, media ID and optional file name. + + Raises: + SynapseError(404): if parsing or validation fail for any reason + """ try: # The type on postpath seems incorrect in Twisted 21.2.0. postpath: List[bytes] = request.postpath # type: ignore @@ -62,6 +75,9 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: server_name = server_name_bytes.decode("utf-8") media_id = media_id_bytes.decode("utf8") + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + file_name = None if len(postpath) > 2: try: diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index bec77088ee..1f6441c412 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -16,7 +16,8 @@ import functools import os import re -from typing import Any, Callable, List, TypeVar, cast +import string +from typing import Any, Callable, List, TypeVar, Union, cast NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") @@ -37,6 +38,113 @@ def _wrap_in_base_path(func: F) -> F: return cast(F, _wrapped) +GetPathMethod = TypeVar( + "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]] +) + + +def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: + """Wraps a path-returning method to check that the returned path(s) do not escape + the media store directory. + + The path-returning method may return either a single path, or a list of paths. + + The check is not expected to ever fail, unless `func` is missing a call to + `_validate_path_component`, or `_validate_path_component` is buggy. + + Args: + relative: A boolean indicating whether the wrapped method returns paths relative + to the media store directory. 
+ + Returns: + A method which will wrap a path-returning method, adding a check to ensure that + the returned path(s) lie within the media store directory. The check will raise + a `ValueError` if it fails. + """ + + def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: + @functools.wraps(func) + def _wrapped( + self: "MediaFilePaths", *args: Any, **kwargs: Any + ) -> Union[str, List[str]]: + path_or_paths = func(self, *args, **kwargs) + + if isinstance(path_or_paths, list): + paths_to_check = path_or_paths + else: + paths_to_check = [path_or_paths] + + for path in paths_to_check: + # Construct the path that will ultimately be used. + # We cannot guess whether `path` is relative to the media store + # directory, since the media store directory may itself be a relative + # path. + if relative: + path = os.path.join(self.base_path, path) + normalized_path = os.path.normpath(path) + + # Now that `normpath` has eliminated `../`s and `./`s from the path, + # `os.path.commonpath` can be used to check whether it lies within the + # media store directory. + if ( + os.path.commonpath([normalized_path, self.normalized_base_path]) + != self.normalized_base_path + ): + # The path resolves to outside the media store directory, + # or `self.base_path` is `.`, which is an unlikely configuration. + raise ValueError(f"Invalid media store path: {path!r}") + + # Note that `os.path.normpath`/`abspath` has a subtle caveat: + # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a + # different path if `a/b/c` is a symlink. That is, the check above is + # not perfect and may allow a certain restricted subset of untrustworthy + # paths through. Since the check above is secondary to the main + # `_validate_path_component` checks, it's less important for it to be + # perfect. + # + # As an alternative, `os.path.realpath` will resolve symlinks, but + # proves problematic if there are symlinks inside the media store. + # eg. if `url_store/` is symlinked to elsewhere, its canonical path + # won't match that of the main media store directory. + + return path_or_paths + + return cast(GetPathMethod, _wrapped) + + return _wrap_with_jail_check_inner + + +ALLOWED_CHARACTERS = set( + string.ascii_letters + + string.digits + + "_-" + + ".[]:" # Domain names, IPv6 addresses and ports in server names +) +FORBIDDEN_NAMES = { + "", + os.path.curdir, # "." for the current platform + os.path.pardir, # ".." for the current platform +} + + +def _validate_path_component(name: str) -> str: + """Checks that the given string can be safely used as a path component + + Args: + name: The path component to check. + + Returns: + The path component if valid. + + Raises: + ValueError: If `name` cannot be safely used as a path component. + """ + if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES: + raise ValueError(f"Invalid path component: {name!r}") + + return name + + class MediaFilePaths: """Describes where files are stored on disk. @@ -47,23 +155,45 @@ class MediaFilePaths: def __init__(self, primary_base_path: str): self.base_path = primary_base_path - + self.normalized_base_path = os.path.normpath(self.base_path) + + # Refuse to initialize if paths cannot be validated correctly for the current + # platform. + assert os.path.sep not in ALLOWED_CHARACTERS + assert os.path.altsep not in ALLOWED_CHARACTERS + # On Windows, paths have all sorts of weirdness which `_validate_path_component` + # does not consider. 
In any case, the remote media store can't work correctly + # for certain homeservers there, since ":"s aren't allowed in paths. + assert os.name == "posix" + + @_wrap_with_jail_check(relative=True) def local_media_filepath_rel(self, media_id: str) -> str: - return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) + return os.path.join( + "local_content", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + @_wrap_with_jail_check(relative=True) def local_media_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( - "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name + "local_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), ) local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + @_wrap_with_jail_check(relative=False) def local_media_thumbnail_dir(self, media_id: str) -> str: """ Retrieve the local store path of thumbnails of a given media_id @@ -76,18 +206,24 @@ class MediaFilePaths: return os.path.join( self.base_path, "local_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), ) + @_wrap_with_jail_check(relative=True) def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( - "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] + "remote_content", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), ) remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + @_wrap_with_jail_check(relative=True) def remote_media_thumbnail_rel( self, server_name: str, @@ -101,11 +237,11 @@ class MediaFilePaths: file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( "remote_thumbnail", - server_name, - file_id[0:2], - file_id[2:4], - file_id[4:], - file_name, + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), ) remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) @@ -113,6 +249,7 @@ class MediaFilePaths: # Legacy path that was used to store thumbnails previously. # Should be removed after some time, when most of the thumbnails are stored # using the new path. 
+ @_wrap_with_jail_check(relative=True) def remote_media_thumbnail_rel_legacy( self, server_name: str, file_id: str, width: int, height: int, content_type: str ) -> str: @@ -120,43 +257,67 @@ class MediaFilePaths: file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( "remote_thumbnail", - server_name, - file_id[0:2], - file_id[2:4], - file_id[4:], - file_name, + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), ) + @_wrap_with_jail_check(relative=False) def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, "remote_thumbnail", - server_name, - file_id[0:2], - file_id[2:4], - file_id[4:], + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), ) + @_wrap_with_jail_check(relative=True) def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf - return os.path.join("url_cache", media_id[:10], media_id[11:]) + return os.path.join( + "url_cache", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) else: - return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:]) + return os.path.join( + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + @_wrap_with_jail_check(relative=False) def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): - return [os.path.join(self.base_path, "url_cache", media_id[:10])] + return [ + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[:10]) + ) + ] else: return [ - os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]), - os.path.join(self.base_path, "url_cache", media_id[0:2]), + os.path.join( + self.base_path, + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[0:2]) + ), ] + @_wrap_with_jail_check(relative=True) def url_cache_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: @@ -168,37 +329,46 @@ class MediaFilePaths: if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - "url_cache_thumbnails", media_id[:10], media_id[11:], file_name + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + _validate_path_component(file_name), ) else: return os.path.join( "url_cache_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], - file_name, + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), ) url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + @_wrap_with_jail_check(relative=True) def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf if 
NEW_FORMAT_ID_RE.match(media_id): - return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:]) + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) else: return os.path.join( "url_cache_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), ) url_cache_thumbnail_directory = _wrap_in_base_path( url_cache_thumbnail_directory_rel ) + @_wrap_with_jail_check(relative=False) def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form <DATE><RANDOM_STRING> @@ -206,21 +376,35 @@ class MediaFilePaths: if NEW_FORMAT_ID_RE.match(media_id): return [ os.path.join( - self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), ), - os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]), ] else: return [ os.path.join( self.base_path, "url_cache_thumbnails", - media_id[0:2], - media_id[2:4], - media_id[4:], + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), ), os.path.join( - self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4] + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), ), - os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]), ] diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 8ca97b5b18..054f3c296d 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -231,7 +231,7 @@ class PreviewUrlResource(DirectServeJsonResource): og = await make_deferred_yieldable(observable.observe()) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url: str, user: str, ts: int) -> bytes: + async def _do_preview(self, url: str, user: UserID, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: @@ -360,7 +360,7 @@ class PreviewUrlResource(DirectServeJsonResource): return jsonog.encode("utf8") - async def _download_url(self, url: str, user: str) -> MediaInfo: + async def _download_url(self, url: str, user: UserID) -> MediaInfo: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? 
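For context on the two `url_cache` layouts being validated above, here is a stripped-down sketch of the path derivation (pre-validation), with an assumed equivalent of `NEW_FORMAT_ID_RE`:

```python
import os
import re

# Assumed equivalent of NEW_FORMAT_ID_RE: date-prefixed IDs such as
# "2017-09-28-fsdRDt24DS234dsf" (see the docstrings in the diff).
NEW_FORMAT_ID_RE = re.compile(r"^\d{4}-\d{2}-\d{2}")

def url_cache_rel_path(media_id: str) -> str:
    if NEW_FORMAT_ID_RE.match(media_id):
        # New format: one directory per day, so expiring a day's worth of
        # cached previews is a single directory removal.
        return os.path.join("url_cache", media_id[:10], media_id[11:])
    # Old format: fan out across two levels of 2-character directories.
    return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])

print(url_cache_rel_path("2017-09-28-fsdRDt24DS234dsf"))
# url_cache/2017-09-28/fsdRDt24DS234dsf
print(url_cache_rel_path("GerZNDnDZVjsOtar"))
# url_cache/Ge/rZ/NDnDZVjsOtar
```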
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource): ) async def _precache_image_url( - self, user: str, media_info: MediaInfo, og: JsonDict + self, user: UserID, media_info: MediaInfo, og: JsonDict ) -> None: """ Pre-cache the image (if one exists) for posterity diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 46701a8b83..5e17664b5b 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -101,8 +101,8 @@ class Thumbnailer: fits within the given rectangle:: (w_in / h_in) = (w_out / h_out) - w_out = min(w_max, h_max * (w_in / h_in)) - h_out = min(h_max, w_max * (h_in / w_in)) + w_out = max(min(w_max, h_max * (w_in / h_in)), 1) + h_out = max(min(h_max, w_max * (h_in / w_in)), 1) Args: max_width: The largest possible width. @@ -110,9 +110,9 @@ class Thumbnailer: """ if max_width * self.height < max_height * self.width: - return max_width, (max_width * self.height) // self.width + return max_width, max((max_width * self.height) // self.width, 1) else: - return (max_height * self.width) // self.height, max_height + return max((max_height * self.width) // self.height, 1), max_height def _resize(self, width: int, height: int) -> Image.Image: # 1-bit or 8-bit color palette images need converting to RGB diff --git a/synapse/server.py b/synapse/server.py index 013a7bacaa..185e40e4da 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -33,9 +33,10 @@ from typing import ( cast, ) -import twisted.internet.tcp +from twisted.internet.interfaces import IOpenSSLContextFactory +from twisted.internet.tcp import Port from twisted.web.iweb import IPolicyForHTTPS -from twisted.web.resource import IResource +from twisted.web.resource import Resource from synapse.api.auth import Auth from synapse.api.filtering import Filtering @@ -96,6 +97,7 @@ from synapse.handlers.room import ( RoomContextHandler, RoomCreationHandler, RoomShutdownHandler, + TimestampLookupHandler, ) from synapse.handlers.room_batch import RoomBatchHandler from synapse.handlers.room_list import RoomListHandler @@ -206,7 +208,7 @@ class HomeServer(metaclass=abc.ABCMeta): Attributes: config (synapse.config.homeserver.HomeserverConfig): - _listening_services (list[twisted.internet.tcp.Port]): TCP ports that + _listening_services (list[Port]): TCP ports that we are listening on to provide HTTP services. """ @@ -225,6 +227,8 @@ class HomeServer(metaclass=abc.ABCMeta): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() + tls_server_context_factory: Optional[IOpenSSLContextFactory] + def __init__( self, hostname: str, @@ -247,7 +251,7 @@ class HomeServer(metaclass=abc.ABCMeta): # the key we use to sign events and requests self.signing_key = config.key.signing_key[0] self.config = config - self._listening_services: List[twisted.internet.tcp.Port] = [] + self._listening_services: List[Port] = [] self.start_time: Optional[int] = None self._instance_id = random_string(5) @@ -257,10 +261,10 @@ class HomeServer(metaclass=abc.ABCMeta): self.datastores: Optional[Databases] = None - self._module_web_resources: Dict[str, IResource] = {} + self._module_web_resources: Dict[str, Resource] = {} self._module_web_resources_consumed = False - def register_module_web_resource(self, path: str, resource: IResource): + def register_module_web_resource(self, path: str, resource: Resource): """Allows a module to register a web resource to be served at the given path. 
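The `thumbnailer.py` hunk above clamps both computed dimensions to at least 1. A self-contained sketch of that calculation, with an assertion showing the degenerate case the clamp fixes:

```python
from typing import Tuple

def fit_within(w_in: int, h_in: int, w_max: int, h_max: int) -> Tuple[int, int]:
    """Scale (w_in, h_in) to fit inside (w_max, h_max), keeping aspect ratio.

    The max(..., 1) clamp is the fix above: an extreme aspect ratio (say
    100000x1) used to round one dimension down to 0, which blows up later
    in the imaging library.
    """
    if w_max * h_in < h_max * w_in:
        return w_max, max((w_max * h_in) // w_in, 1)
    else:
        return max((h_max * w_in) // h_in, 1), h_max

assert fit_within(100000, 1, 800, 600) == (800, 1)  # previously (800, 0)
assert fit_within(1024, 768, 800, 600) == (800, 600)
```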
If multiple modules register a resource for the same path, the module that @@ -726,6 +730,10 @@ class HomeServer(metaclass=abc.ABCMeta): return RoomContextHandler(self) @cache_in_self + def get_timestamp_lookup_handler(self) -> TimestampLookupHandler: + return TimestampLookupHandler(self) + + @cache_in_self def get_registration_handler(self) -> RegistrationHandler: return RegistrationHandler(self) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 1605411b00..446204dbe5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -764,7 +764,7 @@ class StateResolutionStore: store: "DataStore" def get_events( - self, event_ids: Iterable[str], allow_rejected: bool = False + self, event_ids: Collection[str], allow_rejected: bool = False ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 6edadea550..499a328201 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -17,6 +17,7 @@ import logging from typing import ( Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -44,7 +45,7 @@ async def resolve_events_with_store( room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], + state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]], ) -> StateMap[str]: """ Args: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0623da9aa1..3056e64ff5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool from synapse.storage.types import Connection -from synapse.types import StreamToken, get_domain_from_id +from synapse.types import get_domain_from_id from synapse.util import json_decoder if TYPE_CHECKING: @@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta): self, stream_name: str, instance_name: str, - token: StreamToken, + token: int, rows: Iterable[Any], ) -> None: pass diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index b9a8ca997e..d64910aded 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,12 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + Awaitable, + Callable, + Dict, + Iterable, + Optional, +) + +import attr from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.types import Connection from synapse.types import JsonDict -from synapse.util import json_encoder +from synapse.util import Clock, json_encoder from . import engines @@ -28,6 +38,45 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]] +DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] +MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _BackgroundUpdateHandler: + """A handler for a given background update. 
+ + Attributes: + callback: The function to call to make progress on the background + update. + oneshot: Whether the update is likely to happen all in one go, ignoring + the supplied target duration, e.g. index creation. This is used by + the update controller to help correctly schedule the update. + """ + + callback: Callable[[JsonDict, int], Awaitable[int]] + oneshot: bool = False + + +class _BackgroundUpdateContextManager: + BACKGROUND_UPDATE_INTERVAL_MS = 1000 + BACKGROUND_UPDATE_DURATION_MS = 100 + + def __init__(self, sleep: bool, clock: Clock): + self._sleep = sleep + self._clock = clock + + async def __aenter__(self) -> int: + if self._sleep: + await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000) + + return self.BACKGROUND_UPDATE_DURATION_MS + + async def __aexit__(self, *exc) -> None: + pass + + class BackgroundUpdatePerformance: """Tracks how long a background update is taking to update its items""" @@ -82,22 +131,24 @@ class BackgroundUpdater: process and autotuning the batch size. """ - MINIMUM_BACKGROUND_BATCH_SIZE = 100 + MINIMUM_BACKGROUND_BATCH_SIZE = 1 DEFAULT_BACKGROUND_BATCH_SIZE = 100 - BACKGROUND_UPDATE_INTERVAL_MS = 1000 - BACKGROUND_UPDATE_DURATION_MS = 100 def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database + self._database_name = database.name() + # if a background update is currently running, its name. self._current_background_update: Optional[str] = None + self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None + self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None + self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None + self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {} - self._background_update_handlers: Dict[ - str, Callable[[JsonDict, int], Awaitable[int]] - ] = {} + self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {} self._all_done = False # Whether we're currently running updates @@ -107,6 +158,83 @@ class BackgroundUpdater: # enable/disable background updates via the admin API. self.enabled = True + def register_update_controller_callbacks( + self, + on_update: ON_UPDATE_CALLBACK, + default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, + min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, + ) -> None: + """Register callbacks from a module for each hook.""" + if self._on_update_callback is not None: + logger.warning( + "More than one module tried to register callbacks for controlling" + " background updates. Only the callbacks registered by the first module" + " (in order of appearance in Synapse's configuration file) that tried to" + " do so will be called." + ) + + return + + self._on_update_callback = on_update + + if default_batch_size is not None: + self._default_batch_size_callback = default_batch_size + + if min_batch_size is not None: + self._min_batch_size_callback = min_batch_size + + def _get_context_manager_for_update( + self, + sleep: bool, + update_name: str, + database_name: str, + oneshot: bool, + ) -> AsyncContextManager[int]: + """Get a context manager to run a background update with. + + If a module has registered an `update_handler` callback, use the context manager + it returns. + + Otherwise, returns a context manager that will return a default value, optionally + sleeping if needed. + + Args: + sleep: Whether we can sleep between updates. + update_name: The name of the update. 
+ database_name: The name of the database the update is being run on. + oneshot: Whether the update will complete all in one go, e.g. index creation. + In such cases the returned target duration is ignored. + + Returns: + The target duration in milliseconds that the background update should run for. + + Note: this is a *target*, and an iteration may take substantially longer or + shorter. + """ + if self._on_update_callback is not None: + return self._on_update_callback(update_name, database_name, oneshot) + + return _BackgroundUpdateContextManager(sleep, self._clock) + + async def _default_batch_size(self, update_name: str, database_name: str) -> int: + """The batch size to use for the first iteration of a new background + update. + """ + if self._default_batch_size_callback is not None: + return await self._default_batch_size_callback(update_name, database_name) + + return self.DEFAULT_BACKGROUND_BATCH_SIZE + + async def _min_batch_size(self, update_name: str, database_name: str) -> int: + """A lower bound on the batch size of a new background update. + + Used to ensure that progress is always made. Must be greater than 0. + """ + if self._min_batch_size_callback is not None: + return await self._min_batch_size_callback(update_name, database_name) + + return self.MINIMUM_BACKGROUND_BATCH_SIZE + def get_current_update(self) -> Optional[BackgroundUpdatePerformance]: """Returns the current background update, if any.""" @@ -122,6 +250,8 @@ class BackgroundUpdater: def start_doing_background_updates(self) -> None: if self.enabled: + # if we start a new background update, not all updates are done. + self._all_done = False run_as_background_process("background_updates", self.run_background_updates) async def run_background_updates(self, sleep: bool = True) -> None: @@ -133,13 +263,8 @@ class BackgroundUpdater: try: logger.info("Starting background schema updates") while self.enabled: - if sleep: - await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) - try: - result = await self.do_next_background_update( - self.BACKGROUND_UPDATE_DURATION_MS - ) + result = await self.do_next_background_update(sleep) except Exception: logger.exception("Error doing update") else: @@ -201,13 +326,15 @@ class BackgroundUpdater: return not update_exists - async def do_next_background_update(self, desired_duration_ms: float) -> bool: + async def do_next_background_update(self, sleep: bool = True) -> bool: """Does some amount of work on the next queued background update Returns once some amount of work is done. Args: - desired_duration_ms: How long we want to spend updating. + sleep: Whether to limit how quickly we run background updates or + not. + Returns: True if we have finished running all the background updates, otherwise False """ @@ -250,7 +377,19 @@ class BackgroundUpdater: self._current_background_update = upd["update_name"] - await self._do_background_update(desired_duration_ms) + # We have a background update to run, otherwise we would have returned + # early. 
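Putting the new controller hooks together: a module supplies an `ON_UPDATE_CALLBACK`, i.e. a callable returning an async context manager that yields the target batch duration. A hypothetical controller might look like this (the names and pacing choices are illustrative only):

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator

@asynccontextmanager
async def on_update(
    update_name: str, database_name: str, oneshot: bool
) -> AsyncIterator[int]:
    """Matches ON_UPDATE_CALLBACK's shape: entered around each batch."""
    if not oneshot:
        await asyncio.sleep(0.1)  # throttle regular updates a little
    yield 500  # ask for ~500ms batches instead of the 100ms default
    # nothing to clean up on exit

# A module would then register this via the new hook, roughly:
#   updater.register_update_controller_callbacks(on_update=on_update)
```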
+ assert self._current_background_update is not None + update_info = self._background_update_handlers[self._current_background_update] + + async with self._get_context_manager_for_update( + sleep=sleep, + update_name=self._current_background_update, + database_name=self._database_name, + oneshot=update_info.oneshot, + ) as desired_duration_ms: + await self._do_background_update(desired_duration_ms) + return False async def _do_background_update(self, desired_duration_ms: float) -> int: @@ -258,7 +397,7 @@ class BackgroundUpdater: update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) - update_handler = self._background_update_handlers[update_name] + update_handler = self._background_update_handlers[update_name].callback performance = self._background_update_performance.get(update_name) @@ -271,9 +410,14 @@ class BackgroundUpdater: if items_per_ms is not None: batch_size = int(desired_duration_ms * items_per_ms) # Clamp the batch size so that we always make progress - batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE) + batch_size = max( + batch_size, + await self._min_batch_size(update_name, self._database_name), + ) else: - batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE + batch_size = await self._default_batch_size( + update_name, self._database_name + ) progress_json = await self.db_pool.simple_select_one_onecol( "background_updates", @@ -292,6 +436,8 @@ class BackgroundUpdater: duration_ms = time_stop - time_start + performance.update(items_updated, duration_ms) + logger.info( "Running background update %r. Processed %r items in %rms." " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)", @@ -304,8 +450,6 @@ class BackgroundUpdater: batch_size, ) - performance.update(items_updated, duration_ms) - return len(self._background_update_performance) def register_background_update_handler( @@ -329,7 +473,9 @@ class BackgroundUpdater: update_name: The name of the update that this code handles. update_handler: The function that does the update. """ - self._background_update_handlers[update_name] = update_handler + self._background_update_handlers[update_name] = _BackgroundUpdateHandler( + update_handler + ) def register_noop_background_update(self, update_name: str) -> None: """Register a noop handler for a background update. @@ -451,7 +597,9 @@ class BackgroundUpdater: await self._end_background_update(update_name) return 1 - self.register_background_update_handler(update_name, updater) + self._background_update_handlers[update_name] = _BackgroundUpdateHandler( + updater, oneshot=True + ) async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index d4cab69ebf..0693d39006 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -188,7 +188,7 @@ class LoggingDatabaseConnection: # The type of entry which goes on our after_callbacks and exception_callbacks lists. 
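The batch-size autotuning that consumes the yielded duration is simple arithmetic; a sketch mirroring `_do_background_update`, with the patched minimum of 1 so progress is always made:

```python
from typing import Optional

def next_batch_size(
    desired_duration_ms: float,
    items_per_ms: Optional[float],
    default_batch_size: int = 100,
    min_batch_size: int = 1,
) -> int:
    """Scale the last observed throughput to the target duration, but never
    drop below the minimum, so every iteration makes some progress."""
    if items_per_ms is None:
        return default_batch_size  # no performance history yet
    return max(int(desired_duration_ms * items_per_ms), min_batch_size)

assert next_batch_size(100, 0.5) == 50   # 0.5 items/ms at a 100ms target
assert next_batch_size(100, 0.001) == 1  # clamped to the new minimum of 1
assert next_batch_size(100, None) == 100
```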
-_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]] +_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]] R = TypeVar("R") @@ -235,7 +235,7 @@ class LoggingTransaction: self.after_callbacks = after_callbacks self.exception_callbacks = exception_callbacks - def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any): + def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. @@ -247,7 +247,7 @@ class LoggingTransaction: self.after_callbacks.append((callback, args, kwargs)) def call_on_exception( - self, callback: Callable[..., None], *args: Any, **kwargs: Any + self, callback: Callable[..., object], *args: Any, **kwargs: Any ): # if self.exception_callbacks is None, that means that whatever constructed the # LoggingTransaction isn't expecting there to be any callbacks; assert that diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 259cae5b37..9ff2d8d8c3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -123,9 +123,9 @@ class DataStore( RelationsStore, CensorEventsStore, UIAuthStore, + EventForwardExtremitiesStore, CacheInvalidationWorkerStore, ServerMetricsStore, - EventForwardExtremitiesStore, LockStore, SessionStore, ): @@ -154,6 +154,7 @@ class DataStore( db_conn, "local_group_updates", "stream_id" ) + self._cache_id_gen: Optional[MultiWriterIdGenerator] if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about # missing updates over restarts, as we'll not have anything in our diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index baec35ee27..4a883dc166 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore( A list of ApplicationServices, which may be empty. """ results = await self.db_pool.simple_select_list( - "application_services_state", {"state": state}, ["as_id"] + "application_services_state", {"state": state.value}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore( desc="get_appservice_state", ) if result: - return result.get("state") + return ApplicationServiceState(result.get("state")) return None async def set_appservice_state( @@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore( state: The connectivity state to apply. """ await self.db_pool.simple_upsert( - "application_services_state", {"as_id": service.id}, {"state": state} + "application_services_state", {"as_id": service.id}, {"state": state.value} ) async def create_appservice_txn( diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index eee07227ef..0f56e10220 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -13,12 +13,12 @@ # limitations under the License. 
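The `Callable[..., None]` to `Callable[..., object]` change in `database.py` above is the standard typing idiom for callbacks whose return value is ignored; declaring `None` would wrongly reject callbacks that happen to return something. A minimal illustration:

```python
from typing import Any, Callable

def call_after(callback: Callable[..., object], *args: Any) -> None:
    """Accept any callback; the return value is deliberately discarded."""
    callback(*args)

def invalidate_cache(key: str) -> bool:  # returns bool, which call_after ignores
    return True

call_after(invalidate_cache, "some-key")  # type-checks: bool is an object
# Had the parameter stayed Callable[..., None], mypy would reject this call,
# because invalidate_cache's declared return type is bool, not None.
```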
import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.util import json_encoder @@ -41,7 +41,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) @wrap_as_background_process("_censor_redactions") - async def _censor_redactions(self): + async def _censor_redactions(self) -> None: """Censors all redactions older than the configured period that haven't been censored yet. @@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = json_encoder.encode( + pruned_json: Optional[str] = json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) @@ -116,7 +116,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase updates.append((redaction_id, event_id, pruned_json)) - def _update_censor_txn(txn): + def _update_censor_txn(txn: LoggingTransaction) -> None: for redaction_id, event_id, pruned_json in updates: if pruned_json: self._censor_event_txn(txn, event_id, pruned_json) @@ -130,14 +130,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn) - def _censor_event_txn(self, txn, event_id, pruned_json): + def _censor_event_txn( + self, txn: LoggingTransaction, event_id: str, pruned_json: str + ) -> None: """Censor an event by replacing its JSON in the event_json table with the provided pruned JSON. Args: - txn (LoggingTransaction): The database transaction. - event_id (str): The ID of the event to censor. - pruned_json (str): The pruned JSON + txn: The database transaction. + event_id: The ID of the event to censor. + pruned_json: The pruned JSON """ self.db_pool.simple_update_one_txn( txn, @@ -157,7 +159,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase # Try to retrieve the event's content from the database or the event cache. event = await self.get_event(event_id) - def delete_expired_event_txn(txn): + def delete_expired_event_txn(txn: LoggingTransaction) -> None: # Delete the expiry timestamp associated with this event from the database. self._delete_event_expiry_txn(txn, event_id) @@ -194,14 +196,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase "delete_expired_event", delete_expired_event_txn ) - def _delete_event_expiry_txn(self, txn, event_id): + def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None: """Delete the expiry timestamp associated with an event ID without deleting the actual event. Args: - txn (LoggingTransaction): The transaction to use to perform the deletion. - event_id (str): The event ID to delete the associated expiry timestamp of. + txn: The transaction to use to perform the deletion. + event_id: The event ID to delete the associated expiry timestamp of. 
""" - return self.db_pool.simple_delete_txn( + self.db_pool.simple_delete_txn( txn=txn, table="event_expiry", keyvalues={"event_id": event_id} ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index ae3afdd5d2..ab8766c75b 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.replication.tcp.streams import ToDeviceStream from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache @@ -34,14 +43,21 @@ logger = logging.getLogger(__name__) class DeviceInboxWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. - self._last_device_delete_cache = ExpiringCache( + self._last_device_delete_cache: ExpiringCache[ + Tuple[str, Optional[str]], int + ] = ExpiringCache( cache_name="last_device_delete_cache", clock=self._clock, max_len=10000, @@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): self._instance_name in hs.config.worker.writers.to_device ) - self._device_inbox_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - stream_name="to_device", - instance_name=self._instance_name, - tables=[("device_inbox", "instance_name", "stream_id")], - sequence_name="device_inbox_sequence", - writers=hs.config.worker.writers.to_device, + self._device_inbox_id_gen: AbstractStreamIdGenerator = ( + MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="to_device", + instance_name=self._instance_name, + tables=[("device_inbox", "instance_name", "stream_id")], + sequence_name="device_inbox_sequence", + writers=hs.config.worker.writers.to_device, + ) ) else: self._can_write_to_device = True @@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ToDeviceStream.NAME: + # If replication is happening than postgres must be being used. 
+ assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator) self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): @@ -220,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": f"deleted {count} messages for device", "count": count}) # Update the cache, ensuring that we only ever increase the value - last_deleted_stream_id = self._last_device_delete_cache.get( + updated_last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), 0 ) self._last_device_delete_cache[(user_id, device_id)] = max( - last_deleted_stream_id, up_to_stream_id + updated_last_deleted_stream_id, up_to_stream_id ) return count @@ -432,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) async with self._device_inbox_id_gen.get_next() as stream_id: - now_ms = self.clock.time_msec() + now_ms = self._clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) @@ -483,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) async with self._device_inbox_id_gen.get_next() as stream_id: - now_ms = self.clock.time_msec() + now_ms = self._clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, @@ -579,6 +599,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox" REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox" + REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox" def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) @@ -594,14 +615,18 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) - self.db_pool.updates.register_background_update_handler( - self.REMOVE_DELETED_DEVICES, - self._remove_deleted_devices_from_device_inbox, + # Used to be a background update that deletes all device_inboxes for deleted + # devices. + self.db_pool.updates.register_noop_background_update( + self.REMOVE_DELETED_DEVICES ) + # Used to be a background update that deletes all device_inboxes for hidden + # devices. + self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES) self.db_pool.updates.register_background_update_handler( - self.REMOVE_HIDDEN_DEVICES, - self._remove_hidden_devices_from_device_inbox, + self.REMOVE_DEAD_DEVICES_FROM_INBOX, + self._remove_dead_devices_from_device_inbox, ) async def _background_drop_index_device_inbox(self, progress, batch_size): @@ -616,171 +641,83 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): return 1 - async def _remove_deleted_devices_from_device_inbox( - self, progress: JsonDict, batch_size: int + async def _remove_dead_devices_from_device_inbox( + self, + progress: JsonDict, + batch_size: int, ) -> int: - """A background update that deletes all device_inboxes for deleted devices. - - This should only need to be run once (when users upgrade to v1.47.0) + """A background update to remove devices that were either deleted or hidden from + the device_inbox table. Args: - progress: JsonDict used to store progress of this background update - batch_size: the maximum number of rows to retrieve in a single select query + progress: The update's progress dict. + batch_size: The batch size for this update. 
Returns: - The number of deleted rows + The number of rows deleted. """ - def _remove_deleted_devices_from_device_inbox_txn( + def _remove_dead_devices_from_device_inbox_txn( txn: LoggingTransaction, - ) -> int: - """stream_id is not unique - we need to use an inclusive `stream_id >= ?` clause, - since we might not have deleted all dead device messages for the stream_id - returned from the previous query + ) -> Tuple[int, bool]: - Then delete only rows matching the `(user_id, device_id, stream_id)` tuple, - to avoid problems of deleting a large number of rows all at once - due to a single device having lots of device messages. - """ + if "max_stream_id" in progress: + max_stream_id = progress["max_stream_id"] + else: + txn.execute("SELECT max(stream_id) FROM device_inbox") + # There's a type mismatch here between how we want to type the row and + # what fetchone says it returns, but we silence it because we know that + # res can't be None. + res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment] + if res[0] is None: + # this can only happen if the `device_inbox` table is empty, in which + # case we have no work to do. + return 0, True + else: + max_stream_id = res[0] - last_stream_id = progress.get("stream_id", 0) + start = progress.get("stream_id", 0) + stop = start + batch_size + # delete rows in `device_inbox` which do *not* correspond to a known, + # unhidden device. sql = """ - SELECT device_id, user_id, stream_id - FROM device_inbox + DELETE FROM device_inbox WHERE - stream_id >= ? - AND (device_id, user_id) NOT IN ( - SELECT device_id, user_id FROM devices + stream_id >= ? AND stream_id < ? + AND NOT EXISTS ( + SELECT * FROM devices d + WHERE + d.device_id=device_inbox.device_id + AND d.user_id=device_inbox.user_id + AND NOT hidden ) - ORDER BY stream_id - LIMIT ? - """ - - txn.execute(sql, (last_stream_id, batch_size)) - rows = txn.fetchall() - - num_deleted = 0 - for row in rows: - num_deleted += self.db_pool.simple_delete_txn( - txn, - "device_inbox", - {"device_id": row[0], "user_id": row[1], "stream_id": row[2]}, - ) + """ - if rows: - # send more than stream_id to progress - # otherwise it can happen in large deployments that - # no change of status is visible in the log file - # it may be that the stream_id does not change in several runs - self.db_pool.updates._background_update_progress_txn( - txn, - self.REMOVE_DELETED_DEVICES, - { - "device_id": rows[-1][0], - "user_id": rows[-1][1], - "stream_id": rows[-1][2], - }, - ) - - return num_deleted + txn.execute(sql, (start, stop)) - number_deleted = await self.db_pool.runInteraction( - "_remove_deleted_devices_from_device_inbox", - _remove_deleted_devices_from_device_inbox_txn, - ) - - # The task is finished when no more lines are deleted. - if not number_deleted: - await self.db_pool.updates._end_background_update( - self.REMOVE_DELETED_DEVICES + self.db_pool.updates._background_update_progress_txn( + txn, + self.REMOVE_DEAD_DEVICES_FROM_INBOX, + { + "stream_id": stop, + "max_stream_id": max_stream_id, + }, ) - return number_deleted - - async def _remove_hidden_devices_from_device_inbox( - self, progress: JsonDict, batch_size: int - ) -> int: - """A background update that deletes all device_inboxes for hidden devices. 
- - This should only need to be run once (when users upgrade to v1.47.0) - - Args: - progress: JsonDict used to store progress of this background update - batch_size: the maximum number of rows to retrieve in a single select query - - Returns: - The number of deleted rows - """ - - def _remove_hidden_devices_from_device_inbox_txn( - txn: LoggingTransaction, - ) -> int: - """stream_id is not unique - we need to use an inclusive `stream_id >= ?` clause, - since we might not have deleted all hidden device messages for the stream_id - returned from the previous query - - Then delete only rows matching the `(user_id, device_id, stream_id)` tuple, - to avoid problems of deleting a large number of rows all at once - due to a single device having lots of device messages. - """ - - last_stream_id = progress.get("stream_id", 0) - - sql = """ - SELECT device_id, user_id, stream_id - FROM device_inbox - WHERE - stream_id >= ? - AND (device_id, user_id) IN ( - SELECT device_id, user_id FROM devices WHERE hidden = ? - ) - ORDER BY stream_id - LIMIT ? - """ - - txn.execute(sql, (last_stream_id, True, batch_size)) - rows = txn.fetchall() - - num_deleted = 0 - for row in rows: - num_deleted += self.db_pool.simple_delete_txn( - txn, - "device_inbox", - {"device_id": row[0], "user_id": row[1], "stream_id": row[2]}, - ) - - if rows: - # We don't just save the `stream_id` in progress as - # otherwise it can happen in large deployments that - # no change of status is visible in the log file, as - # it may be that the stream_id does not change in several runs - self.db_pool.updates._background_update_progress_txn( - txn, - self.REMOVE_HIDDEN_DEVICES, - { - "device_id": rows[-1][0], - "user_id": rows[-1][1], - "stream_id": rows[-1][2], - }, - ) - - return num_deleted + return stop > max_stream_id - number_deleted = await self.db_pool.runInteraction( - "_remove_hidden_devices_from_device_inbox", - _remove_hidden_devices_from_device_inbox_txn, + finished = await self.db_pool.runInteraction( + "_remove_devices_from_device_inbox_txn", + _remove_dead_devices_from_device_inbox_txn, ) - # The task is finished when no more lines are deleted. - if not number_deleted: + if finished: await self.db_pool.updates._end_background_update( - self.REMOVE_HIDDEN_DEVICES + self.REMOVE_DEAD_DEVICES_FROM_INBOX, ) - return number_deleted + return batch_size class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9ccc66e589..838a2a6a3d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} + async def get_devices_by_auth_provider_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> List[Dict[str, Any]]: + """Retrieve the list of devices associated with a SSO IdP session ID. 
+ + Args: + auth_provider_id: The SSO IdP ID as defined in the server config + auth_provider_session_id: The session ID within the IdP + Returns: + A list of dicts containing the device_id and the user_id of each device + """ + return await self.db_pool.simple_select_list( + table="device_auth_providers", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + retcols=("user_id", "device_id"), + desc="get_devices_by_auth_provider_session_id", + ) + @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int @@ -253,7 +274,9 @@ class DeviceWorkerStore(SQLBaseStore): # add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("m.signing_key_update", result)) + # also send the unstable version + # FIXME: remove this when enough servers have upgraded results.append(("org.matrix.signing_key_update", result)) return now_stream_id, results @@ -1070,7 +1093,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: Optional[str] + self, + user_id: str, + device_id: str, + initial_device_display_name: Optional[str], + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1079,6 +1107,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: id of device initial_device_display_name: initial displayname of the device. Ignored if device exists. + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from a OIDC login. Returns: Whether the device was inserted or an existing device existed with that ID. @@ -1115,6 +1145,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) + if auth_provider_id and auth_provider_session_id: + await self.db_pool.simple_insert( + "device_auth_providers", + values={ + "user_id": user_id, + "device_id": device_id, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="store_device_auth_provider", + ) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: @@ -1168,6 +1210,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): keyvalues={"user_id": user_id}, ) + self.db_pool.simple_delete_many_txn( + txn, + table="device_auth_providers", + column="device_id", + values=device_ids, + keyvalues={"user_id": user_id}, + ) + await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 6daf8b8ffb..a3442814d7 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -13,17 +13,18 @@ # limitations under the License. 
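How the new `device_auth_providers` pieces fit together, as a hypothetical usage sketch (the store object and all values are illustrative):

```python
from typing import Any

async def bind_and_lookup(store: Any) -> None:
    # `store` is a DeviceStore as patched above; all values are made up.
    await store.store_device(
        user_id="@alice:example.com",
        device_id="HLDVVXZYGS",
        initial_device_display_name="Work laptop",
        auth_provider_id="oidc-corp",            # IdP id from the server config
        auth_provider_session_id="session-123",  # `sid` from the OIDC login
    )
    # A back-channel logout for that IdP session can now locate the device(s):
    for entry in await store.get_devices_by_auth_provider_session_id(
        "oidc-corp", "session-123"
    ):
        print(entry["user_id"], entry["device_id"])
```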
from collections import namedtuple -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Tuple from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) -class DirectoryWorkerStore(SQLBaseStore): +class DirectoryWorkerStore(CacheInvalidationWorkerStore): async def get_association_from_room_alias( self, room_alias: RoomAlias ) -> Optional[RoomAliasMapping]: @@ -91,7 +92,7 @@ class DirectoryWorkerStore(SQLBaseStore): creator: Optional user_id of creator. """ - def alias_txn(txn): + def alias_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_insert_txn( txn, "room_aliases", @@ -126,14 +127,16 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): - async def delete_room_alias(self, room_alias: RoomAlias) -> str: + async def delete_room_alias(self, room_alias: RoomAlias) -> Optional[str]: room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) return room_id - def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: + def _delete_room_alias_txn( + self, txn: LoggingTransaction, room_alias: RoomAlias + ) -> Optional[str]: txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),), @@ -173,9 +176,9 @@ class DirectoryStore(DirectoryWorkerStore): If None, the creator will be left unchanged. """ - def _update_aliases_for_room_txn(txn): + def _update_aliases_for_room_txn(txn: LoggingTransaction) -> None: update_creator_sql = "" - sql_params = (new_room_id, old_room_id) + sql_params: Tuple[str, ...] = (new_room_id, old_room_id) if creator: update_creator_sql = ", creator = ?" sql_params = (new_room_id, creator, old_room_id) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index a95ac34f09..b06c1dc45b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): fallback_keys: the keys to set. This is a map from key ID (which is of the form "algorithm:id") to key data. 
""" + await self.db_pool.runInteraction( + "set_e2e_fallback_keys_txn", + self._set_e2e_fallback_keys_txn, + user_id, + device_id, + fallback_keys, + ) + + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) + ) + + def _set_e2e_fallback_keys_txn( + self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict + ) -> None: # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad # FIXME: make sure that only one key per algorithm is uploaded for key_id, fallback_key in fallback_keys.items(): algorithm, key_id = key_id.split(":", 1) - await self.db_pool.simple_upsert( - "e2e_fallback_keys_json", + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_fallback_keys_json", keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, }, - values={ - "key_id": key_id, - "key_json": json_encoder.encode(fallback_key), - "used": False, - }, - desc="set_e2e_fallback_key", + retcol="key_json", + allow_none=True, ) - await self.invalidate_cache_and_stream( - "get_e2e_unused_fallback_key_types", (user_id, device_id) - ) + new_key_json = encode_canonical_json(fallback_key).decode("utf-8") + + # If the uploaded key is the same as the current fallback key, + # don't do anything. This prevents marking the key as unused if it + # was already used. + if old_key_json != new_key_json: + self.db_pool.simple_upsert_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": False, + }, + ) @cached(max_entries=10000) async def get_e2e_unused_fallback_key_types( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index c58f7dd009..df70524fef 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1672,9 +1672,9 @@ class EventFederationStore(EventFederationWorkerStore): DELETE FROM event_auth WHERE event_id IN ( SELECT event_id FROM events - LEFT JOIN state_events USING (room_id, event_id) + LEFT JOIN state_events AS se USING (room_id, event_id) WHERE ? <= stream_ordering AND stream_ordering < ? - AND state_key IS null + AND se.state_key IS null ) """ diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index d957e770dc..3efdd0c920 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr +from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json @@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [ ] +class BasePushAction(TypedDict): + event_id: str + actions: List[Union[dict, str]] + + +class HttpPushAction(BasePushAction): + room_id: str + stream_ordering: int + + +class EmailPushAction(HttpPushAction): + received_ts: Optional[int] + + def _serialize_action(actions, is_highlight): """Custom serializer for actions. This allows us to "compress" common actions. 
@@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[dict]: + ) -> List[HttpPushAction]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. @@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): min_stream_ordering: int, max_stream_ordering: int, limit: int = 20, - ) -> List[dict]: + ) -> List[EmailPushAction]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 3790a52a89..28d9d99b46 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1,6 +1,6 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ # limitations under the License. import itertools import logging -from collections import OrderedDict, namedtuple +from collections import OrderedDict from typing import ( TYPE_CHECKING, Any, @@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder @@ -64,9 +65,6 @@ event_counter = Counter( ) -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) - - @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -108,23 +106,30 @@ class PersistEventsStore: self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id - # Ideally we'd move these ID gens here, unfortunately some other ID - # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen - self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen - # This should only exist on instances that are configured to write assert ( hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" + # Since we have been configured to write, we ought to have id generators, + # rather than id trackers. + assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator) + assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator) + + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. 
+ self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen + self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen + async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], + *, current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], - backfilled: bool = False, + use_negative_stream_ordering: bool = False, + inhibit_local_membership_updates: bool = False, ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -137,7 +142,14 @@ class PersistEventsStore: room state new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. - backfilled + use_negative_stream_ordering: Whether to start stream_ordering on + the negative side and decrement. This should be set as True + for backfilled events because backfilled events get a negative + stream ordering so they don't come down incremental `/sync`. + inhibit_local_membership_updates: Stop the local_current_membership + from being updated by these events. This should be set to True + for backfilled events because backfilled events in the past do + not affect the current local state. Returns: Resolves when the events have been persisted @@ -159,7 +171,7 @@ class PersistEventsStore: # # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. - if backfilled: + if use_negative_stream_ordering: stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) @@ -176,13 +188,13 @@ class PersistEventsStore: "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, - backfilled=backfilled, + inhibit_local_membership_updates=inhibit_local_membership_updates, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc(len(events_and_contexts)) - if not backfilled: + if stream < 0: # backfilled events have negative stream orderings, so we don't # want to set the event_persisted_position to that. synapse.metrics.event_persisted_position.set( @@ -316,8 +328,9 @@ class PersistEventsStore: def _persist_events_txn( self, txn: LoggingTransaction, + *, events_and_contexts: List[Tuple[EventBase, EventContext]], - backfilled: bool, + inhibit_local_membership_updates: bool = False, state_delta_for_room: Optional[Dict[str, DeltaState]] = None, new_forward_extremeties: Optional[Dict[str, List[str]]] = None, ): @@ -330,7 +343,10 @@ class PersistEventsStore: Args: txn events_and_contexts: events to persist - backfilled: True if the events were backfilled + inhibit_local_membership_updates: Stop the local_current_membership + from being updated by these events. This should be set to True + for backfilled events because backfilled events in the past do + not affect the current local state. delete_existing True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. @@ -363,9 +379,7 @@ class PersistEventsStore: events_and_contexts ) - self._update_room_depths_txn( - txn, events_and_contexts=events_and_contexts, backfilled=backfilled - ) + self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts) # _update_outliers_txn filters out any events which have already been # persisted, and returns the filtered list. 
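A toy model of why backfill uses a separate, negative ID generator, as the docstring above explains (the counters here are illustrative stand-ins for `_backfill_id_gen` and `_stream_id_gen`):

```python
import itertools

# Live events count upwards; backfilled events count *downwards* from -1, so a
# client syncing from a token of 0 never sees history stitched in afterwards.
live_ids = itertools.count(1)           # 1, 2, 3, ...
backfill_ids = itertools.count(-1, -1)  # -1, -2, -3, ...

def persist(backfilled: bool) -> int:
    return next(backfill_ids) if backfilled else next(live_ids)

since_token = 0
orderings = [persist(False), persist(True), persist(False), persist(True)]
new_for_sync = [o for o in orderings if o > since_token]
print(orderings)     # [1, -1, 2, -2]
print(new_for_sync)  # [1, 2] - backfilled events stay out of incremental /sync
```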
@@ -398,7 +412,7 @@ class PersistEventsStore: txn, events_and_contexts=events_and_contexts, all_events_and_contexts=all_events_and_contexts, - backfilled=backfilled, + inhibit_local_membership_updates=inhibit_local_membership_updates, ) # We call this last as it assumes we've inserted the events into @@ -561,9 +575,9 @@ class PersistEventsStore: # fetch their auth event info. while missing_auth_chains: sql = """ - SELECT event_id, events.type, state_key, chain_id, sequence_number + SELECT event_id, events.type, se.state_key, chain_id, sequence_number FROM events - INNER JOIN state_events USING (event_id) + INNER JOIN state_events AS se USING (event_id) LEFT JOIN event_auth_chains USING (event_id) WHERE """ @@ -1200,7 +1214,6 @@ class PersistEventsStore: self, txn, events_and_contexts: List[Tuple[EventBase, EventContext]], - backfilled: bool, ): """Update min_depth for each room @@ -1208,13 +1221,18 @@ class PersistEventsStore: txn (twisted.enterprise.adbapi.Connection): db connection events_and_contexts (list[(EventBase, EventContext)]): events we are persisting - backfilled (bool): True if the events were backfilled """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) - if not backfilled: + # Then update the `stream_ordering` position to mark the latest + # event as the front of the room. This should not be done for + # backfilled events because backfilled events have negative + # stream_ordering and happened in the past so we know that we don't + # need to update the stream_ordering tip/front for the room. + assert event.internal_metadata.stream_ordering is not None + if event.internal_metadata.stream_ordering >= 0: txn.call_after( self.store._events_stream_cache.entity_has_changed, event.room_id, @@ -1427,7 +1445,12 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _update_metadata_tables_txn( - self, txn, events_and_contexts, all_events_and_contexts, backfilled + self, + txn, + *, + events_and_contexts, + all_events_and_contexts, + inhibit_local_membership_updates: bool = False, ): """Update all the miscellaneous tables for new events @@ -1439,7 +1462,10 @@ class PersistEventsStore: events that we were going to persist. This includes events we've already persisted, etc, that wouldn't appear in events_and_context. - backfilled (bool): True if the events were backfilled + inhibit_local_membership_updates: Stop the local_current_membership + from being updated by these events. This should be set to True + for backfilled events because backfilled events in the past do + not affect the current local state. """ # Insert all the push actions into the event_push_actions table. @@ -1513,7 +1539,7 @@ class PersistEventsStore: for event, _ in events_and_contexts if event.type == EventTypes.Member ], - backfilled=backfilled, + inhibit_local_membership_updates=inhibit_local_membership_updates, ) # Insert event_reference_hashes table. 
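Note the new bare `*` markers throughout these hunks: the refactor replaces the positional `backfilled` boolean with keyword-only flags, so every call site has to spell out which behaviour it wants. In miniature:

```python
def persist_events(
    events_and_contexts: list,
    *,  # everything after this marker is keyword-only
    use_negative_stream_ordering: bool = False,
    inhibit_local_membership_updates: bool = False,
) -> None:
    """Signature sketch of the pattern this patch applies throughout."""

# persist_events([], True, False)  # TypeError: no anonymous boolean flags
persist_events([], use_negative_stream_ordering=True)  # intent is explicit
```

Splitting the single `backfilled` flag into two named flags also decouples the two side effects it used to control (stream-ordering direction and membership updates), which is why both appear as separate parameters.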
@@ -1553,11 +1579,13 @@ class PersistEventsStore: for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: - to_prefill.append(_EventCacheEntry(event=event, redacted_event=None)) + to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) + self.store._get_event_cache.set( + (cache_entry.event.event_id,), cache_entry + ) txn.call_after(prefill) @@ -1638,11 +1666,22 @@ class PersistEventsStore: txn, table="event_reference_hashes", values=vals ) - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database.""" + def _store_room_members_txn( + self, txn, events, *, inhibit_local_membership_updates: bool = False + ): + """ + Store a room member in the database. + Args: + txn: The transaction to use. + events: List of events to store. + inhibit_local_membership_updates: Stop the local_current_membership + from being updated by these events. This should be set to True + for backfilled events because backfilled events in the past do + not affect the current local state. + """ - def str_or_none(val: Any) -> Optional[str]: - return val if isinstance(val, str) else None + def non_null_str_or_none(val: Any) -> Optional[str]: + return val if isinstance(val, str) and "\u0000" not in val else None self.db_pool.simple_insert_many_txn( txn, @@ -1654,8 +1693,10 @@ class PersistEventsStore: "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, - "display_name": str_or_none(event.content.get("displayname")), - "avatar_url": str_or_none(event.content.get("avatar_url")), + "display_name": non_null_str_or_none( + event.content.get("displayname") + ), + "avatar_url": non_null_str_or_none(event.content.get("avatar_url")), } for event in events ], @@ -1680,7 +1721,7 @@ class PersistEventsStore: # band membership", like a remote invite or a rejection of a remote invite. if ( self.is_mine_id(event.state_key) - and not backfilled + and not inhibit_local_membership_updates and event.internal_metadata.is_outlier() and event.internal_metadata.is_out_of_band_membership() ): @@ -1694,34 +1735,33 @@ class PersistEventsStore: }, ) - def _handle_event_relations(self, txn, event): - """Handles inserting relation data during peristence of events + def _handle_event_relations( + self, txn: LoggingTransaction, event: EventBase + ) -> None: + """Handles inserting relation data during persistence of events Args: - txn - event (EventBase) + txn: The current database transaction. + event: The event which might have relations. """ relation = event.content.get("m.relates_to") if not relation: # No relations return + # Relations must have a type and parent event ID. rel_type = relation.get("rel_type") - if rel_type not in ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.REPLACE, - RelationTypes.THREAD, - ): - # Unknown relation type + if not isinstance(rel_type, str): return parent_id = relation.get("event_id") - if not parent_id: - # Invalid relation + if not isinstance(parent_id, str): return - aggregation_key = relation.get("key") + # Annotations have a key field. 
+ aggregation_key = None + if rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.get("key") self.db_pool.simple_insert_txn( txn, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index ae3a8a63e4..c88fd35e7f 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1,4 +1,4 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._purged_chain_cover_index, ) + # The event_thread_relation background update was replaced with the + # event_arbitrary_relations one, which handles any relation to avoid + # needed to potentially crawl the entire events table in the future. + self.db_pool.updates.register_noop_background_update("event_thread_relation") + self.db_pool.updates.register_background_update_handler( - "event_thread_relation", self._event_thread_relation + "event_arbitrary_relations", + self._event_arbitrary_relations, ) ################################################################################ @@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result - async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int: - """Background update handler which will store thread relations for existing events.""" + async def _event_arbitrary_relations( + self, progress: JsonDict, batch_size: int + ) -> int: + """Background update handler which will store previously unknown relations for existing events.""" last_event_id = progress.get("last_event_id", "") - def _event_thread_relation_txn(txn: LoggingTransaction) -> int: + def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int: + # Fetch events and then filter based on whether the event has a + # relation or not. txn.execute( """ SELECT event_id, json FROM event_json - LEFT JOIN event_relations USING (event_id) - WHERE event_id > ? AND event_relations.event_id IS NULL + WHERE event_id > ? ORDER BY event_id LIMIT ? """, (last_event_id, batch_size), ) results = list(txn) - missing_thread_relations = [] + # (event_id, parent_id, rel_type) for each relation + relations_to_insert: List[Tuple[str, str, str]] = [] for (event_id, event_json_raw) in results: try: event_json = db_to_json(event_json_raw) @@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) continue - # If there's no relation (or it is not a thread), skip! + # If there's no relation, skip! relates_to = event_json["content"].get("m.relates_to") if not relates_to or not isinstance(relates_to, dict): continue - if relates_to.get("rel_type") != RelationTypes.THREAD: + + # If the relation type or parent event ID is not a string, skip it. + # + # Do not consider relation types that have existed for a long time, + # since they will already be listed in the `event_relations` table. + rel_type = relates_to.get("rel_type") + if not isinstance(rel_type, str) or rel_type in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): continue - # Get the parent ID. 
parent_id = relates_to.get("event_id") if not isinstance(parent_id, str): continue - missing_thread_relations.append((event_id, parent_id)) + relations_to_insert.append((event_id, parent_id, rel_type)) + + # Insert the missing data, note that we upsert here in case the event + # has already been processed. + if relations_to_insert: + self.db_pool.simple_upsert_many_txn( + txn=txn, + table="event_relations", + key_names=("event_id",), + key_values=[(r[0],) for r in relations_to_insert], + value_names=("relates_to_id", "relation_type"), + value_values=[r[1:] for r in relations_to_insert], + ) - # Insert the missing data. - self.db_pool.simple_insert_many_txn( - txn=txn, - table="event_relations", - values=[ - { - "event_id": event_id, - "relates_to_Id": parent_id, - "relation_type": RelationTypes.THREAD, - } - for event_id, parent_id in missing_thread_relations - ], - ) + # Iterate the parent IDs and invalidate caches. + for parent_id in {r[1] for r in relations_to_insert}: + cache_tuple = (parent_id,) + self._invalidate_cache_and_stream( + txn, self.get_relations_for_event, cache_tuple + ) + self._invalidate_cache_and_stream( + txn, self.get_aggregation_groups_for_event, cache_tuple + ) + self._invalidate_cache_and_stream( + txn, self.get_thread_summary, cache_tuple + ) if results: latest_event_id = results[-1][0] self.db_pool.updates._background_update_progress_txn( - txn, "event_thread_relation", {"last_event_id": latest_event_id} + txn, "event_arbitrary_relations", {"last_event_id": latest_event_id} ) return len(results) num_rows = await self.db_pool.runInteraction( - desc="event_thread_relation", func=_event_thread_relation_txn + desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn ) if not num_rows: - await self.db_pool.updates._end_background_update("event_thread_relation") + await self.db_pool.updates._end_background_update( + "event_arbitrary_relations" + ) return num_rows diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 6d2688d711..68901b4335 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -13,15 +13,20 @@ # limitations under the License. import logging -from typing import Dict, List +from typing import Any, Dict, List from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main import CacheInvalidationWorkerStore +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore logger = logging.getLogger(__name__) -class EventForwardExtremitiesStore(SQLBaseStore): +class EventForwardExtremitiesStore( + EventFederationWorkerStore, + CacheInvalidationWorkerStore, +): async def delete_forward_extremities_for_room(self, room_id: str) -> int: """Delete any extra forward extremities for a room. @@ -31,7 +36,7 @@ class EventForwardExtremitiesStore(SQLBaseStore): Returns count deleted. 
""" - def delete_forward_extremities_for_room_txn(txn): + def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int: # First we need to get the event_id to not delete sql = """ SELECT event_id FROM event_forward_extremities @@ -82,10 +87,14 @@ class EventForwardExtremitiesStore(SQLBaseStore): delete_forward_extremities_for_room_txn, ) - async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: + async def get_forward_extremities_for_room( + self, room_id: str + ) -> List[Dict[str, Any]]: """Get list of forward extremities for a room.""" - def get_forward_extremities_for_room_txn(txn): + def get_forward_extremities_for_room_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: sql = """ SELECT event_id, state_group, depth, received_ts FROM event_forward_extremities diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c6bf316d5b..c7b660ac5a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -15,14 +15,18 @@ import logging import threading from typing import ( + TYPE_CHECKING, + Any, Collection, Container, Dict, Iterable, List, + NoReturn, Optional, Set, Tuple, + cast, overload, ) @@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, + RoomVersion, RoomVersions, ) from synapse.events import EventBase, make_event_from_dict @@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import ( + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError @@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) -# These values are used in the `enqueus_event` and `_do_fetch` methods to +# These values are used in the `enqueue_event` and `_fetch_loop` methods to # control how we batch/bulk fetch events from the database. 
# The values are plucked out of thin air to make initial sync run faster
# on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(


 @attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
     event: EventBase
     redacted_event: Optional[EventBase]

@@ -129,7 +145,7 @@ class _EventRow:
     json: str
     internal_metadata: str
     format_version: Optional[int]
-    room_version_id: Optional[int]
+    room_version_id: Optional[str]
     rejected_reason: Optional[str]
     redactions: List[str]
     outlier: bool

@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
     # options controlling this.
     USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True

-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

+        self._stream_id_gen: AbstractStreamIdTracker
+        self._backfill_id_gen: AbstractStreamIdTracker
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres then we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
             5 * 60 * 1000,
         )

-        self._get_event_cache = LruCache(
+        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
         # ID to cache entry. Note that the returned dict may not have the
         # requested event in it if the event isn't in the DB.
         self._current_event_fetches: Dict[
-            str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+            str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}

         self._event_fetch_lock = threading.Condition()
-        self._event_fetch_list = []
+        self._event_fetch_list: List[
+            Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+        ] = []
         self._event_fetch_ongoing = 0
         event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)

         # We define this sequence here so that it can be referenced from both
         # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
+        def get_chain_id_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]

         self.event_chain_id_gen = build_sequence_generator(
             db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
             id_column="chain_id",
         )

-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[False] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[False] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> EventBase:
         ...
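The `@overload` stubs above (and in the next hunk) swap their defaults for the `...` sentinel. Overload signatures are never executed, so concrete defaults in them are noise, and the old `allow_none: Literal[True] = False` default was not even a valid value of its declared type; only the real implementation keeps the actual defaults. A reduced sketch of the same trick, using a hypothetical `fetch` helper:

    from typing import Literal, Optional, overload


    @overload
    def fetch(event_id: str, allow_none: Literal[False] = ...) -> str:
        ...


    @overload
    def fetch(event_id: str, allow_none: Literal[True] = ...) -> Optional[str]:
        ...


    def fetch(event_id: str, allow_none: bool = False) -> Optional[str]:
        # Only the implementation carries a real default; the overloads exist
        # purely so the type checker can tie the return type to `allow_none`.
        if not event_id.startswith("$"):
            if allow_none:
                return None
            raise KeyError(event_id)
        return event_id

With this, `fetch(some_id)` type-checks as `str`, while `fetch(some_id, allow_none=True)` type-checks as `Optional[str]`.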
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore): self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, - get_prev_content: bool = False, - allow_rejected: bool = False, - allow_none: Literal[True] = False, - check_room_id: Optional[str] = None, + get_prev_content: bool = ..., + allow_rejected: bool = ..., + allow_none: Literal[True] = ..., + check_room_id: Optional[str] = ..., ) -> Optional[EventBase]: ... @@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_events( self, - event_ids: Iterable[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, @@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore): async def _get_events_from_cache_or_db( self, event_ids: Iterable[str], allow_rejected: bool = False - ) -> Dict[str, _EventCacheEntry]: + ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore): # same dict into itself N times). already_fetching_ids: Set[str] = set() already_fetching_deferreds: Set[ - ObservableDeferred[Dict[str, _EventCacheEntry]] + ObservableDeferred[Dict[str, EventCacheEntry]] ] = set() for event_id in missing_events_ids: @@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore): # function returning more events than requested, but that can happen # already due to `_get_events_from_db`). fetching_deferred: ObservableDeferred[ - Dict[str, _EventCacheEntry] - ] = ObservableDeferred(defer.Deferred()) + Dict[str, EventCacheEntry] + ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) for event_id in missing_events_ids: self._current_event_fetches[event_id] = fetching_deferred @@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore): return event_entry_map - def _invalidate_get_event_cache(self, event_id): + def _invalidate_get_event_cache(self, event_id: str) -> None: self._get_event_cache.invalidate((event_id,)) def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True - ) -> Dict[str, _EventCacheEntry]: + ) -> Dict[str, EventCacheEntry]: """Fetch events from the caches. May return rejected events. @@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore): for e in state_to_include.values() ] - def _do_fetch(self, conn: Connection) -> None: + def _maybe_start_fetch_thread(self) -> None: + """Starts an event fetch thread if we are not yet at the maximum number.""" + with self._event_fetch_lock: + if ( + self._event_fetch_list + and self._event_fetch_ongoing < EVENT_QUEUE_THREADS + ): + self._event_fetch_ongoing += 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + # `_event_fetch_ongoing` is decremented in `_fetch_thread`. 
+ should_start = True + else: + should_start = False + + if should_start: + run_as_background_process("fetch_events", self._fetch_thread) + + async def _fetch_thread(self) -> None: + """Services requests for events from `_event_fetch_list`.""" + exc = None + try: + await self.db_pool.runWithConnection(self._fetch_loop) + except BaseException as e: + exc = e + raise + finally: + should_restart = False + event_fetches_to_fail = [] + with self._event_fetch_lock: + self._event_fetch_ongoing -= 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + + # There may still be work remaining in `_event_fetch_list` if we + # failed, or it was added in between us deciding to exit and + # decrementing `_event_fetch_ongoing`. + if self._event_fetch_list: + if exc is None: + # We decided to exit, but then some more work was added + # before `_event_fetch_ongoing` was decremented. + # If a new event fetch thread was not started, we should + # restart ourselves since the remaining event fetch threads + # may take a while to get around to the new work. + # + # Unfortunately it is not possible to tell whether a new + # event fetch thread was started, so we restart + # unconditionally. If we are unlucky, we will end up with + # an idle fetch thread, but it will time out after + # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds + # in any case. + # + # Note that multiple fetch threads may run down this path at + # the same time. + should_restart = True + elif isinstance(exc, Exception): + if self._event_fetch_ongoing == 0: + # We were the last remaining fetcher and failed. + # Fail any outstanding fetches since no one else will + # handle them. + event_fetches_to_fail = self._event_fetch_list + self._event_fetch_list = [] + else: + # We weren't the last remaining fetcher, so another + # fetcher will pick up the work. This will either happen + # after their existing work, however long that takes, + # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if + # they are idle. + pass + else: + # The exception is a `SystemExit`, `KeyboardInterrupt` or + # `GeneratorExit`. Don't try to do anything clever here. + pass + + if should_restart: + # We exited cleanly but noticed more work. + self._maybe_start_fetch_thread() + + if event_fetches_to_fail: + # We were the last remaining fetcher and failed. + # Fail any outstanding fetches since no one else will handle them. + assert exc is not None + with PreserveLoggingContext(): + for _, deferred in event_fetches_to_fail: + deferred.errback(exc) + + def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ - try: - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if ( - not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING - or single_threaded - or i > EVENT_QUEUE_ITERATIONS - ): - break - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + # There are no requests waiting. If we haven't yet reached the + # maximum iteration limit, wait for some more requests to turn up. + # Otherwise, bail out. 
+ single_threaded = self.database_engine.single_threaded + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): + return + + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 - self._fetch_event_list(conn, event_list) - finally: - self._event_fetch_ongoing -= 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + self._fetch_event_list(conn, event_list) def _fetch_event_list( - self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] + self, + conn: LoggingDatabaseConnection, + event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]], ) -> None: """Handle a load of requests from the _event_fetch_list queue @@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore): ) # We only want to resolve deferreds from the main thread - def fire(): + def fire() -> None: for _, d in event_list: d.callback(row_dict) @@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire(evs, exc): - for _, d in evs: - if not d.called: - with PreserveLoggingContext(): - d.errback(exc) + def fire_errback(exc: Exception) -> None: + for _, d in event_list: + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list, e) + self.hs.get_reactor().callFromThread(fire_errback, e) async def _get_events_from_db( - self, event_ids: Iterable[str] - ) -> Dict[str, _EventCacheEntry]: + self, event_ids: Collection[str] + ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the database. May return rejected events. @@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore): map from event id to result. May return extra events which weren't asked for. """ - fetched_events = {} + fetched_event_ids: Set[str] = set() + fetched_events: Dict[str, _EventRow] = {} events_to_fetch = event_ids while events_to_fetch: row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events - redaction_ids = set() + redaction_ids: Set[str] = set() for event_id in events_to_fetch: row = row_map.get(event_id) - fetched_events[event_id] = row + fetched_event_ids.add(event_id) if row: + fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_events.keys()) + events_to_fetch = redaction_ids.difference(fetched_event_ids) if events_to_fetch: logger.debug("Also fetching redaction events %s", events_to_fetch) # build a map from event_id to EventBase - event_map = {} + event_map: Dict[str, EventBase] = {} for event_id, row in fetched_events.items(): - if not row: - continue assert row.event_id == event_id rejected_reason = row.rejected_reason @@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row.room_version_id + room_version: Optional[RoomVersion] if not room_version_id: # this should only happen for out-of-band membership events which # arrived before #6983 landed. For all other events, we should have @@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore): # finally, we can decide whether each one needs redacting, and build # the cache entries. 
- result_map = {} + result_map: Dict[str, EventCacheEntry] = {} for event_id, original_ev in event_map.items(): redactions = fetched_events[event_id].redactions redacted_event = self._maybe_redact_event_row( original_ev, redactions, event_map ) - cache_entry = _EventCacheEntry( + cache_entry = EventCacheEntry( event=original_ev, redacted_event=redacted_event ) @@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]: + async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore): that weren't requested. """ - events_d = defer.Deferred() + events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred() with self._event_fetch_lock: self._event_fetch_list.append((events, events_d)) - self._event_fetch_lock.notify() - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) - should_start = True - else: - should_start = False - - if should_start: - run_as_background_process( - "fetch_events", self.db_pool.runWithConnection, self._do_fetch - ) + self._maybe_start_fetch_thread() logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): @@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - async def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]: """Given a list of event ids, check if we have already processed and stored them as non outliers. """ @@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore): event_ids: events we are looking for Returns: - set[str]: The events we have already seen. + The set of events we have already seen. """ res = await self._have_seen_events_dict( (room_id, event_id) for event_id in event_ids @@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore): } results = {x: True for x in cache_results} - def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): + def have_seen_events_txn( + txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] + ) -> None: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore): return results @cached(max_entries=100000, tree=True) - async def have_seen_event(self, room_id: str, event_id: str): + async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn: # this only exists for the benefit of the @cachedList descriptor on # _have_seen_events_dict raise NotImplementedError() - def _get_current_state_event_counts_txn(self, txn, room_id): + def _get_current_state_event_counts_txn( + self, txn: LoggingTransaction, room_id: str + ) -> int: """ See get_current_state_event_counts. """ @@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - async def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id: str) -> Dict[str, float]: """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. 
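As the `_enqueue_events` hunk above shows, the producer side now just parks a `(event_ids, Deferred)` pair on `_event_fetch_list` and calls `_maybe_start_fetch_thread`; the fetcher threads drain the list in batches and retire after idling. A stripped-down, synchronous model of that queue discipline (toy code: plain lists and `threading.Condition`, no Twisted and no thread cap):

    import threading
    from typing import List

    _lock = threading.Condition()
    _fetch_list: List[List[str]] = []


    def enqueue(event_ids: List[str]) -> None:
        # Producer: queue the request and wake one idle fetcher, if any.
        with _lock:
            _fetch_list.append(event_ids)
            _lock.notify()


    def fetch_loop(timeout: float = 0.1, max_idle_iterations: int = 3) -> None:
        # Consumer: drain everything queued in one pass; give the thread back
        # after sitting idle for a few timeouts (cf. EVENT_QUEUE_ITERATIONS).
        idle = 0
        while True:
            with _lock:
                batch = list(_fetch_list)
                _fetch_list.clear()
                if not batch:
                    if idle > max_idle_iterations:
                        return
                    _lock.wait(timeout)
                    idle += 1
                    continue
                idle = 0
            print("fetching %d batched requests" % len(batch))


    enqueue(["$event_a:example.org", "$event_b:example.org"])
    fetch_loop()  # drains once, then exits after a few idle timeouts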
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore): more resources. Args: - room_id (str) + room_id: The room ID to query. Returns: - dict[str:int] of complexity version to complexity. + dict[str:float] of complexity version to complexity. """ state_events = await self.get_current_state_event_counts(room_id) @@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_events_token(self): + def get_current_events_token(self) -> int: """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: """Returns new events, for the Events replication stream Args: @@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. """ - def get_all_new_forward_event_rows(txn): + def get_all_new_forward_event_rows( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN state_events AS se USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (last_id, current_id, instance_name, limit)) - return txn.fetchall() + return cast( + List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + ) return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows @@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: """Returns de-outliered events, for the Events replication stream Args: @@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore): EventsStreamRow. 
""" - def get_ex_outlier_stream_rows_txn(txn): + def get_ex_outlier_stream_rows_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN state_events AS se USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " LEFT JOIN room_memberships USING (event_id)" " LEFT JOIN rejections USING (event_id)" @@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore): ) txn.execute(sql, (last_id, current_id, instance_name)) - return txn.fetchall() + return cast( + List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + ) return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn @@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_new_backfill_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. @@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_new_backfill_event_rows(txn): + def get_all_new_backfill_event_rows( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " se.state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN state_events AS se USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " AND instance_name = ?" @@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates = [(row[0], row[1:]) for row in txn] + new_event_updates: List[ + Tuple[int, Tuple[str, str, str, str, str, str]] + ] = [] + row: Tuple[int, str, str, str, str, str, str] + # Type safety: iterating over `txn` yields `Tuple`, i.e. + # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a + # variadic tuple to a fixed length tuple and flags it up as an error. + for row in txn: # type: ignore[assignment] + new_event_updates.append((row[0], row[1:])) limited = False if len(new_event_updates) == limit: @@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore): sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " se.state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN state_events AS se USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" 
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore): " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound, instance_name)) - new_event_updates.extend((row[0], row[1:]) for row in txn) + # Type safety: iterating over `txn` yields `Tuple`, i.e. + # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a + # variadic tuple to a fixed length tuple and flags it up as an error. + for row in txn: # type: ignore[assignment] + new_event_updates.append((row[0], row[1:])) if len(new_event_updates) >= limit: upper_bound = new_event_updates[-1][0] @@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_updated_current_state_deltas( self, instance_name: str, from_token: int, to_token: int, target_row_count: int - ) -> Tuple[List[Tuple], int, bool]: + ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]: """Fetch updates from current_state_delta_stream Args: @@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore): * `limited` is whether there are more updates to fetch. """ - def get_all_updated_current_state_deltas_txn(txn): + def get_all_updated_current_state_deltas_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str, str]]: sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream @@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore): ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) - def get_deltas_for_stream_id_txn(txn, stream_id): + def get_deltas_for_stream_id_txn( + txn: LoggingTransaction, stream_id: int + ) -> List[Tuple[int, str, str, str, str]]: sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE stream_id = ? """ txn.execute(sql, [stream_id]) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows: List[Tuple] = await self.db_pool.runInteraction( + rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True - async def is_event_after(self, event_id1, event_id2): + async def is_event_after(self, event_id1: str, event_id2: str) -> bool: """Returns True if event_id1 is after event_id2 in the stream""" to_1, so_1 = await self.get_event_ordering(event_id1) to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cached(max_entries=5000) - async def get_event_ordering(self, event_id): + async def get_event_ordering(self, event_id: str) -> Tuple[int, int]: res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], @@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore): None otherwise. 
""" - def get_next_event_to_expire_txn(txn): + def get_next_event_to_expire_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, int]]: txn.execute( """ SELECT event_id, expiry_ts FROM event_expiry @@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore): """ ) - return txn.fetchone() + return cast(Optional[Tuple[str, int]], txn.fetchone()) return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn @@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore): return mapping @wrap_as_background_process("_cleanup_old_transaction_ids") - async def _cleanup_old_transaction_ids(self): + async def _cleanup_old_transaction_ids(self) -> None: """Cleans out transaction id mappings older than 24hrs.""" - def _cleanup_old_transaction_ids_txn(txn): + def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: sql = """ DELETE FROM event_txn_id WHERE inserted_ts < ? @@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore): "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, ) + + async def is_event_next_to_backward_gap(self, event: EventBase) -> bool: + """Check if the given event is next to a backward gap of missing events. + <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages> + + Args: + room_id: room where the event lives + event_id: event to check + + Returns: + Boolean indicating whether it's an extremity + """ + + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: + # If the event in question has any of its prev_events listed as a + # backward extremity, it's next to a gap. + # + # We can't just check the backward edges in `event_edges` because + # when we persist events, we will also record the prev_events as + # edges to the event in question regardless of whether we have those + # prev_events yet. We need to check whether those prev_events are + # backward extremities, also known as gaps, that need to be + # backfilled. + backward_extremity_query = """ + SELECT 1 FROM event_backward_extremities + WHERE + room_id = ? + AND %s + LIMIT 1 + """ + + # If the event in question is a backward extremity or has any of its + # prev_events listed as a backward extremity, it's next to a + # backward gap. + clause, args = make_in_list_sql_clause( + self.database_engine, + "event_id", + [event.event_id] + list(event.prev_event_ids()), + ) + + txn.execute(backward_extremity_query % (clause,), [event.room_id] + args) + backward_extremities = txn.fetchall() + + # We consider any backward extremity as a backward gap + if len(backward_extremities): + return True + + return False + + return await self.db_pool.runInteraction( + "is_event_next_to_backward_gap_txn", + is_event_next_to_backward_gap_txn, + ) + + async def is_event_next_to_forward_gap(self, event: EventBase) -> bool: + """Check if the given event is next to a forward gap of missing events. + The gap in front of the latest events is not considered a gap. 
+        <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
+        <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>
+
+        Args:
+            event: The event to check. The room to look in is taken from
+                `event.room_id`.
+
+        Returns:
+            True if the event is next to a forward gap of missing events.
+        """
+
+        def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
+            # If the event in question is a forward extremity, we will just
+            # consider any potential forward gap as not a gap since it's one of
+            # the latest events in the room.
+            #
+            # `event_forward_extremities` does not include backfilled or outlier
+            # events so we can't rely on it to find forward gaps. We can only
+            # use it to determine whether a message is the latest in the room.
+            #
+            # We can't combine this query with the `forward_edge_query` below
+            # because if the event in question has no forward edges (isn't
+            # referenced by any other event's prev_events) but is in
+            # `event_forward_extremities`, we don't want to return 0 rows and
+            # say it's next to a gap.
+            forward_extremity_query = """
+                SELECT 1 FROM event_forward_extremities
+                WHERE
+                    room_id = ?
+                    AND event_id = ?
+                LIMIT 1
+            """
+
+            # Check to see whether the event in question is already referenced
+            # by another event. If we don't see any edges, we're next to a
+            # forward gap.
+            forward_edge_query = """
+                SELECT 1 FROM event_edges
+                /* Check to make sure the event referencing our event in question is not rejected */
+                LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
+                WHERE
+                    event_edges.room_id = ?
+                    AND event_edges.prev_event_id = ?
+                    /* It's not a valid edge if the event referencing our event in
+                     * question is rejected.
+                     */
+                    AND rejections.event_id IS NULL
+                LIMIT 1
+            """
+
+            # We consider any forward extremity as the latest in the room and
+            # not a forward gap.
+            #
+            # To expand, even though there is technically a gap at the front of
+            # the room where the forward extremities are, we consider those the
+            # latest messages in the room so asking other homeservers for more
+            # is useless. The new latest messages will just be federated as
+            # usual.
+            txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+            forward_extremities = txn.fetchall()
+            if len(forward_extremities):
+                return False
+
+            # If there are no forward edges to the event in question (another
+            # event hasn't referenced this event in their prev_events), then we
+            # assume there is a forward gap in the history.
+            txn.execute(forward_edge_query, (event.room_id, event.event_id))
+            forward_edges = txn.fetchall()
+            if not len(forward_edges):
+                return True
+
+            return False
+
+        return await self.db_pool.runInteraction(
+            "is_event_next_to_gap_txn",
+            is_event_next_to_gap_txn,
+        )
+
+    async def get_event_id_for_timestamp(
+        self, room_id: str, timestamp: int, direction: str
+    ) -> Optional[str]:
+        """Find the closest event to the given timestamp in the given direction.
+
+        Args:
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            The closest event_id, otherwise None if we can't find any event in
+            the given direction.
+        """
+
+        sql_template = """
+            SELECT event_id FROM events
+            LEFT JOIN rejections USING (event_id)
+            WHERE
+                origin_server_ts %s ?
+                AND room_id = ?
+ /* Make sure event is not rejected */ + AND rejections.event_id IS NULL + ORDER BY origin_server_ts %s + LIMIT 1; + """ + + def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp. We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" + + txn.execute( + sql_template % (comparison_operator, order), (timestamp, room_id) + ) + row = txn.fetchone() + if row: + (event_id,) = row + return event_id + + return None + + if direction not in ("f", "b"): + raise ValueError("Unknown direction: %s" % (direction,)) + + return await self.db_pool.runInteraction( + "get_event_id_for_timestamp_txn", + get_event_id_for_timestamp_txn, + ) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 434986fa64..cf842803bc 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +19,7 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -49,7 +51,7 @@ class FilteringStore(SQLBaseStore): # Need an atomic transaction to SELECT the maximal ID so far then # INSERT a new one - def _do_txn(txn): + def _do_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT filter_id FROM user_filters " "WHERE user_id = ? AND filter_json = ?" @@ -61,7 +63,7 @@ class FilteringStore(SQLBaseStore): sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" txn.execute(sql, (user_localpart,)) - max_id = txn.fetchone()[0] + max_id = txn.fetchone()[0] # type: ignore[index] if max_id is None: filter_id = 0 else: diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 3d0df0cbd4..a540f7fb26 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from types import TracebackType -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type +from typing import TYPE_CHECKING, Optional, Tuple, Type from weakref import WeakValueDictionary from twisted.internet.interfaces import IReactorCore @@ -62,7 +62,9 @@ class LockStore(SQLBaseStore): # A map from `(lock_name, lock_key)` to the token of any locks that we # think we currently hold. - self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary() + self._live_tokens: WeakValueDictionary[ + Tuple[str, str], Lock + ] = WeakValueDictionary() # When we shut down we want to remove the locks. Technically this can # lead to a race, as we may drop the lock while we are still processing. 
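`get_event_id_for_timestamp` above (part of the experimental MSC3030 endpoint) derives both the comparison operator and the sort order from its `direction` flag, so a single SQL template serves both lookups. The mapping on its own (a sketch; the real transaction interpolates these into the larger query shown above):

    from typing import Tuple


    def timestamp_clause(direction: str) -> Tuple[str, str]:
        """Return (comparison operator, sort order) for the closest-event query."""
        if direction == "b":
            # Closest event at or *before* the timestamp: order descending and
            # take the first row, i.e. the largest origin_server_ts under the bound.
            return "<=", "DESC"
        elif direction == "f":
            # Closest event at or *after* the timestamp: order ascending.
            return ">=", "ASC"
        raise ValueError("Unknown direction: %s" % (direction,))


    print(timestamp_clause("b"))  # ('<=', 'DESC')
    print(timestamp_clause("f"))  # ('>=', 'ASC')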
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 717487be28..1b076683f7 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,10 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -46,7 +61,12 @@ class MediaSortOrder(Enum): class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -102,13 +122,15 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): self._drop_media_index_without_method, ) - async def _drop_media_index_without_method(self, progress, batch_size): + async def _drop_media_index_without_method( + self, progress: JsonDict, batch_size: int + ) -> int: """background update handler which removes the old constraints. Note that this is only run on postgres. """ - def f(txn): + def f(txn: LoggingTransaction) -> None: txn.execute( "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key" ) @@ -126,7 +148,12 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -174,7 +201,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): plus the total count of all the user's media """ - def get_local_media_by_user_paginate_txn(txn): + def get_local_media_by_user_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: # Set ordering order_by_column = MediaSortOrder(order_by).value @@ -184,14 +213,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): else: order = "ASC" - args = [user_id] + args: List[Union[str, int]] = [user_id] sql = """ SELECT COUNT(*) as total_media FROM local_media_repository WHERE user_id = ? 
""" txn.execute(sql, args) - count = txn.fetchone()[0] + count = txn.fetchone()[0] # type: ignore[index] sql = """ SELECT @@ -268,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) sql += sql_keep - def _get_local_media_before_txn(txn): + def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (before_ts, before_ts, size_gt)) return [row[0] for row in txn] @@ -278,13 +307,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_local_media( self, - media_id, - media_type, - time_now_ms, - upload_name, - media_length, - user_id, - url_cache=None, + media_id: str, + media_type: str, + time_now_ms: int, + upload_name: Optional[str], + media_length: int, + user_id: UserID, + url_cache: Optional[str] = None, ) -> None: await self.db_pool.simple_insert( "local_media_repository", @@ -315,7 +344,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): None if the URL isn't cached. """ - def get_url_cache_txn(txn): + def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: # get the most recently cached result (relative to the given ts) sql = ( "SELECT response_code, etag, expires_ts, og, media_id, download_ts" @@ -359,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts - ): + ) -> None: await self.db_pool.simple_insert( "local_media_repository_url_cache", { @@ -390,13 +419,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_local_thumbnail( self, - media_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): + media_id: str, + thumbnail_width: int, + thumbnail_height: int, + thumbnail_type: str, + thumbnail_method: str, + thumbnail_length: int, + ) -> None: await self.db_pool.simple_upsert( table="local_media_repository_thumbnails", keyvalues={ @@ -430,14 +459,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_cached_remote_media( self, - origin, - media_id, - media_type, - media_length, - time_now_ms, - upload_name, - filesystem_id, - ): + origin: str, + media_id: str, + media_type: str, + media_length: int, + time_now_ms: int, + upload_name: Optional[str], + filesystem_id: str, + ) -> None: await self.db_pool.simple_insert( "remote_media_cache", { @@ -458,7 +487,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): local_media: Iterable[str], remote_media: Iterable[Tuple[str, str]], time_ms: int, - ): + ) -> None: """Updates the last access time of the given media Args: @@ -467,7 +496,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): time_ms: Current time in milliseconds """ - def update_cache_txn(txn): + def update_cache_txn(txn: LoggingTransaction) -> None: sql = ( "UPDATE remote_media_cache SET last_access_ts = ?" " WHERE media_origin = ? AND media_id = ?" 
@@ -488,7 +517,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media)) - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn ) @@ -542,15 +571,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def store_remote_media_thumbnail( self, - origin, - media_id, - filesystem_id, - thumbnail_width, - thumbnail_height, - thumbnail_type, - thumbnail_method, - thumbnail_length, - ): + origin: str, + media_id: str, + filesystem_id: str, + thumbnail_width: int, + thumbnail_height: int, + thumbnail_type: str, + thumbnail_method: str, + thumbnail_length: int, + ) -> None: await self.db_pool.simple_upsert( table="remote_media_cache_thumbnails", keyvalues={ @@ -566,7 +595,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_remote_media_thumbnail", ) - async def get_remote_media_before(self, before_ts): + async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]: sql = ( "SELECT media_origin, media_id, filesystem_id" " FROM remote_media_cache" @@ -602,7 +631,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " LIMIT 500" ) - def _get_expired_url_cache_txn(txn): + def _get_expired_url_cache_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (now_ts,)) return [row[0] for row in txn] @@ -610,18 +639,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_expired_url_cache", _get_expired_url_cache_txn ) - async def delete_url_cache(self, media_ids): + async def delete_url_cache(self, media_ids: Collection[str]) -> None: if len(media_ids) == 0: return sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" - def _delete_url_cache_txn(txn): + def _delete_url_cache_txn(txn: LoggingTransaction) -> None: txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) - return await self.db_pool.runInteraction( - "delete_url_cache", _delete_url_cache_txn - ) + await self.db_pool.runInteraction("delete_url_cache", _delete_url_cache_txn) async def get_url_cache_media_before(self, before_ts: int) -> List[str]: sql = ( @@ -631,7 +658,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " LIMIT 500" ) - def _get_url_cache_media_before_txn(txn): + def _get_url_cache_media_before_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (before_ts,)) return [row[0] for row in txn] @@ -639,11 +666,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_url_cache_media_before", _get_url_cache_media_before_txn ) - async def delete_url_cache_media(self, media_ids): + async def delete_url_cache_media(self, media_ids: Collection[str]) -> None: if len(media_ids) == 0: return - def _delete_url_cache_media_txn(txn): + def _delete_url_cache_media_txn(txn: LoggingTransaction) -> None: sql = "DELETE FROM local_media_repository WHERE media_id = ?" 
txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) @@ -652,6 +679,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index 2aac64901b..a46685219f 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -1,6 +1,21 @@ +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction class OpenIdStore(SQLBaseStore): @@ -20,7 +35,7 @@ class OpenIdStore(SQLBaseStore): async def get_user_id_for_open_id_token( self, token: str, ts_now_ms: int ) -> Optional[str]: - def get_user_id_for_token_txn(txn): + def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]: sql = ( "SELECT user_id FROM open_id_tokens" " WHERE token = ? AND ? <= ts_valid_until_ms" diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index dd8e27e226..e197b7203e 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo @@ -104,7 +105,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="update_remote_profile_cache", ) - async def maybe_delete_remote_profile_cache(self, user_id): + async def maybe_delete_remote_profile_cache(self, user_id: str) -> None: """Check if we still care about the remote user's profile, and if we don't then remove their profile from the cache """ @@ -116,9 +117,9 @@ class ProfileWorkerStore(SQLBaseStore): desc="delete_remote_profile_cache", ) - async def is_subscribed_remote_profile_for_user(self, user_id): + async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool: """Check whether we are interested in a remote user's profile.""" - res = await self.db_pool.simple_select_one_onecol( + res: Optional[str] = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -139,13 +140,16 @@ class ProfileWorkerStore(SQLBaseStore): if res: return True + return False async def get_remote_profile_cache_entries_that_expire( self, last_checked: int ) -> List[Dict[str, str]]: """Get all users who haven't been checked since `last_checked`""" - def _get_remote_profile_cache_entries_that_expire_txn(txn): + def _get_remote_profile_cache_entries_that_expire_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, str]]: sql = """ SELECT user_id, 
displayname, avatar_url FROM remote_profile_cache diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 3eb30944bf..91b0576b85 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): logger.info("[purge] looking for events to delete") - should_delete_expr = "state_key IS NULL" + should_delete_expr = "state_events.state_key IS NULL" should_delete_params: Tuple[Any, ...] = () if not delete_local_events: should_delete_expr += " AND event_id NOT LIKE ?" diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index fa782023d4..3b63267395 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdTracker, + StreamIdGenerator, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -82,9 +85,9 @@ class PushRulesWorkerStore( super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: Union[ - StreamIdGenerator, SlavedIdTracker - ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id") + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, "push_rules_stream", "stream_id" + ) else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index c99f8aebdb..9c5625c8bb 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,14 +14,25 @@ # limitations under the License. 
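The one-line purge_events change above (`state_key IS NULL` becoming `state_events.state_key IS NULL`) is part of a broader disambiguation, matched by the roommember.py hunks and the SCHEMA_VERSION = 66 note later in this diff: once the `events` table can itself carry a `state_key` column, unqualified references in joins become ambiguous. A hedged before/after sketch, using the table aliases the diff itself uses:

    # Before: ambiguous once more than one joined table has a state_key column.
    sql_before = """
        SELECT c.event_id FROM current_state_events AS c
        INNER JOIN events AS e USING (room_id, event_id)
        WHERE c.type = 'm.room.member' AND state_key = ?
    """

    # After: the column is pinned to the intended table alias.
    sql_after = """
        SELECT c.event_id FROM current_state_events AS c
        INNER JOIN events AS e USING (room_id, event_id)
        WHERE c.type = 'm.room.member' AND c.state_key = ?
    """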
import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from twisted.internet import defer +from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict @@ -78,17 +89,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - def get_max_receipt_stream_id(self): - """Get the current max stream ID for receipts stream - - Returns: - int - """ + def get_max_receipt_stream_id(self) -> int: + """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @cached() - async def get_users_with_read_receipts_in_room(self, room_id): - receipts = await self.get_receipts_for_room(room_id, "m.read") + async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]: + receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) return {r["user_id"] for r in receipts} @cached(num_args=2) @@ -119,7 +126,9 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=2) - async def get_receipts_for_user(self, user_id, receipt_type): + async def get_receipts_for_user( + self, user_id: str, receipt_type: str + ) -> Dict[str, str]: rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, @@ -129,8 +138,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return {row["room_id"]: row["event_id"] for row in rows} - async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): - def f(txn): + async def get_receipts_for_user_with_orderings( + self, user_id: str, receipt_type: str + ) -> JsonDict: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" @@ -209,10 +220,10 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3, tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonDict]: """See get_linearized_receipts_for_room""" - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" @@ -250,11 +261,13 @@ class ReceiptsWorkerStore(SQLBaseStore): list_name="room_ids", num_args=3, ) - async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms( + self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + ) -> Dict[str, List[JsonDict]]: if not room_ids: return {} - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -323,7 +336,7 @@ class ReceiptsWorkerStore(SQLBaseStore): A dictionary of roomids to a list of receipts. 
""" - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -379,7 +392,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return defer.succeed([]) - def _get_users_sent_receipts_between_txn(txn): + def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT user_id FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? @@ -419,7 +432,9 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_receipts_txn(txn): + def get_all_updated_receipts_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized @@ -446,8 +461,8 @@ class ReceiptsWorkerStore(SQLBaseStore): def _invalidate_get_users_with_receipts_in_room( self, room_id: str, receipt_type: str, user_id: str - ): - if receipt_type != "m.read": + ) -> None: + if receipt_type != ReceiptTypes.READ: return res = self.get_users_with_read_receipts_in_room.cache.get_immediate( @@ -461,7 +476,9 @@ class ReceiptsWorkerStore(SQLBaseStore): self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + def invalidate_caches_for_receipt( + self, room_id: str, receipt_type: str, user_id: str + ) -> None: self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( @@ -482,11 +499,18 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_id, data, stream_id - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + data: JsonDict, + stream_id: int, + ) -> Optional[int]: """Inserts a read-receipt into the database if it's newer than the current RR - Returns: int|None + Returns: None if the RR is older than the current RR otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) @@ -550,7 +574,7 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) - if receipt_type == "m.read" and stream_ordering is not None: + if receipt_type == ReceiptTypes.READ and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -580,7 +604,7 @@ class ReceiptsWorkerStore(SQLBaseStore): else: # we need to points in graph -> linearized form. # TODO: Make this better. 
- def graph_to_linear(txn): + def graph_to_linear(txn: LoggingTransaction) -> str: clause, args = make_in_list_sql_clause( self.database_engine, "event_id", event_ids ) @@ -634,11 +658,16 @@ class ReceiptsWorkerStore(SQLBaseStore): return stream_id, max_persisted_id async def insert_graph_receipt( - self, room_id, receipt_type, user_id, event_ids, data - ): + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -649,8 +678,14 @@ class ReceiptsWorkerStore(SQLBaseStore): ) def insert_graph_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_ids, data - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6c7d6ba508..e1ddf06916 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -84,28 +84,37 @@ class TokenLookupResult: return self.user_id -@attr.s(frozen=True, slots=True) +@attr.s(auto_attribs=True, frozen=True, slots=True) class RefreshTokenLookupResult: """Result of looking up a refresh token.""" - user_id = attr.ib(type=str) + user_id: str """The user this token belongs to.""" - device_id = attr.ib(type=str) + device_id: str """The device associated with this refresh token.""" - token_id = attr.ib(type=int) + token_id: int """The ID of this refresh token.""" - next_token_id = attr.ib(type=Optional[int]) + next_token_id: Optional[int] """The ID of the refresh token which replaced this one.""" - has_next_refresh_token_been_refreshed = attr.ib(type=bool) + has_next_refresh_token_been_refreshed: bool """True if the next refresh token was used for another refresh.""" - has_next_access_token_been_used = attr.ib(type=bool) + has_next_access_token_been_used: bool """True if the next access token was already used at least once.""" + expiry_ts: Optional[int] + """The time at which the refresh token expires and can not be used. + If None, the refresh token doesn't expire.""" + + ultimate_session_expiry_ts: Optional[int] + """The time at which the session comes to an end and can no longer be + refreshed. + If None, the session can be refreshed indefinitely.""" + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( @@ -476,7 +485,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): shadow_banned: true iff the user is to be shadow-banned, false otherwise. 
""" - def set_shadow_banned_txn(txn): + def set_shadow_banned_txn(txn: LoggingTransaction) -> None: user_id = user.to_string() self.db_pool.simple_update_one_txn( txn, @@ -1198,8 +1207,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts = now_ms + self._account_validity_period if use_delta: + assert self._account_validity_startup_job_max_delta is not None expiration_ts = random.randrange( - expiration_ts - self._account_validity_startup_job_max_delta, + int(expiration_ts - self._account_validity_startup_job_max_delta), expiration_ts, ) @@ -1625,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): rt.user_id, rt.device_id, rt.next_token_id, - (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, - at.used has_next_access_token_been_used + (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed, + at.used AS has_next_access_token_been_used, + rt.expiry_ts, + rt.ultimate_session_expiry_ts FROM refresh_tokens rt LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id @@ -1647,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): has_next_refresh_token_been_refreshed=row[4], # This column is nullable, ensure it's a boolean has_next_access_token_been_used=(row[5] or False), + expiry_ts=row[6], + ultimate_session_expiry_ts=row[7], ) return await self.db_pool.runInteraction( @@ -1728,11 +1742,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): ) self.db_pool.updates.register_background_update_handler( - "user_threepids_grandfather", self._bg_user_threepids_grandfather + "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - self.db_pool.updates.register_background_update_handler( - "users_set_deactivated_flag", self._background_update_set_deactivated_flag + self.db_pool.updates.register_noop_background_update( + "user_threepids_grandfather" ) self.db_pool.updates.register_background_index_update( @@ -1805,35 +1819,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): return nb_processed - async def _bg_user_threepids_grandfather(self, progress, batch_size): - """We now track which identity servers a user binds their 3PID to, so - we need to handle the case of existing bindings where we didn't track - this. - - We do this by grandfathering in existing user threepids assuming that - they used one of the server configured trusted identity servers. - """ - id_servers = set(self.config.registration.trusted_third_party_id_servers) - - def _bg_user_threepids_grandfather_txn(txn): - sql = """ - INSERT INTO user_threepid_id_server - (user_id, medium, address, id_server) - SELECT user_id, medium, address, ? - FROM user_threepids - """ - - txn.execute_batch(sql, [(id_server,) for id_server in id_servers]) - - if id_servers: - await self.db_pool.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn - ) - - await self.db_pool.updates._end_background_update("user_threepids_grandfather") - - return 1 - async def set_user_deactivated_status( self, user_id: str, deactivated: bool ) -> None: @@ -1943,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): user_id: str, token: str, device_id: Optional[str], + expiry_ts: Optional[int], + ultimate_session_expiry_ts: Optional[int], ) -> int: """Adds a refresh token for the given user. 
@@ -1950,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): user_id: The user ID. token: The new access token to add. device_id: ID of the device to associate with the refresh token. + expiry_ts (milliseconds since the epoch): Time after which the + refresh token cannot be used. + If None, the refresh token never expires until it has been used. + ultimate_session_expiry_ts (milliseconds since the epoch): + Time at which the session will end and can not be extended any + further. + If None, the session can be refreshed indefinitely. Raises: StoreError if there was a problem adding this. Returns: @@ -1965,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "device_id": device_id, "token": token, "next_token_id": None, + "expiry_ts": expiry_ts, + "ultimate_session_expiry_ts": ultimate_session_expiry_ts, }, desc="add_refresh_token_to_user", ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 907af10995..0a43acda07 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def event_includes_relation(self, event_id: str) -> bool: + """Check if the given event relates to another event. + + An event has a relation if it has a valid m.relates_to with a rel_type + and event_id in the content: + + { + "content": { + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$other_event_id" + } + } + } + + Args: + event_id: The event to check. + + Returns: + True if the event includes a valid relation. + """ + + result = await self.db_pool.simple_select_one_onecol( + table="event_relations", + keyvalues={"event_id": event_id}, + retcol="event_id", + allow_none=True, + desc="event_includes_relation", + ) + return result is not None + + async def event_is_target_of_relation(self, parent_id: str) -> bool: + """Check if the given event is the target of another event's relation. + + An event is the target of an event relation if it has a valid + m.relates_to with a rel_type and event_id pointing to parent_id in the + content: + + { + "content": { + "m.relates_to": { + "rel_type": "m.replace", + "event_id": "$parent_id" + } + } + } + + Args: + parent_id: The event to check. + + Returns: + True if the event is the target of another event's relation. 
+ """ + + result = await self.db_pool.simple_select_one_onecol( + table="event_relations", + keyvalues={"relates_to_id": parent_id}, + retcol="event_id", + allow_none=True, + desc="event_is_target_of_relation", + ) + return result is not None + @cached(tree=True) async def get_aggregation_groups_for_event( self, @@ -362,7 +425,7 @@ class RelationsWorkerStore(SQLBaseStore): %s; """ - def _get_if_event_has_relations(txn) -> List[str]: + def _get_if_events_have_relations(txn) -> List[str]: clauses: List[str] = [] clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", parent_ids @@ -387,7 +450,7 @@ class RelationsWorkerStore(SQLBaseStore): return [row[0] for row in txn] return await self.db_pool.runInteraction( - "get_if_event_has_relations", _get_if_event_has_relations + "get_if_events_have_relations", _get_if_events_have_relations ) async def has_user_annotated_event( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 17b398bb69..7d694d852d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -397,6 +397,20 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + async def room_is_blocked_by(self, room_id: str) -> Optional[str]: + """ + Function to retrieve user who has blocked the room. + user_id is non-nullable + It returns None if the room is not blocked. + """ + return await self.db_pool.simple_select_one_onecol( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + allow_none=True, + desc="room_is_blocked_by", + ) + async def get_rooms_paginate( self, start: int, @@ -1775,3 +1789,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.is_room_blocked, (room_id,), ) + + async def unblock_room(self, room_id: str) -> None: + """Remove the room from blocking list. + + Args: + room_id: Room to unblock + """ + await self.db_pool.simple_delete( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + desc="unblock_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py index 97b2618437..39e80f6f5b 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py @@ -39,13 +39,11 @@ class RoomBatchStore(SQLBaseStore): async def store_state_group_id_for_event_id( self, event_id: str, state_group_id: int - ) -> Optional[str]: - { - await self.db_pool.simple_upsert( - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - values={"state_group": state_group_id, "event_id": event_id}, - # Unique constraint on event_id so we don't have to lock - lock=False, - ) - } + ) -> None: + await self.db_pool.simple_upsert( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + values={"state_group": state_group_id, "event_id": event_id}, + # Unique constraint on event_id so we don't have to lock + lock=False, + ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 033a9831d6..6b2a8d06a6 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): INNER JOIN events AS e USING (room_id, event_id) WHERE c.type = 'm.room.member' - AND state_key = ? + AND c.state_key = ? 
AND c.membership = ? """ else: @@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): INNER JOIN events AS e USING (room_id, event_id) WHERE c.type = 'm.room.member' - AND state_key = ? + AND c.state_key = ? AND m.membership = ? """ diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index ab2159c2d3..3201623fe4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -63,12 +63,12 @@ class SignatureWorkerStore(SQLBaseStore): A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. """ hashes = await self.get_event_reference_hashes(event_ids) - hashes = { + encoded_hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} for e_id, h in hashes.items() } - return list(hashes.items()) + return list(encoded_hashes.items()) def _get_event_reference_hashes_txn( self, txn: Cursor, event_id: str diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index a89747d741..188afec332 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -16,11 +16,17 @@ import logging from typing import Any, Dict, List, Tuple from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction +from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): + # This class must be mixed in with a child class which provides the following + # attribute. TODO: can we get static analysis to enforce this? + _curr_state_delta_stream_cache: StreamChangeCache + async def get_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[Dict[str, Any]]]: @@ -50,7 +56,9 @@ class StateDeltasStore(SQLBaseStore): prev_stream_id = int(prev_stream_id) # check we're not going backwards - assert prev_stream_id <= max_stream_id + assert ( + prev_stream_id <= max_stream_id + ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}" if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id @@ -60,7 +68,9 @@ class StateDeltasStore(SQLBaseStore): # max_stream_id. return max_stream_id, [] - def get_current_state_deltas_txn(txn): + def get_current_state_deltas_txn( + txn: LoggingTransaction, + ) -> Tuple[int, List[Dict[str, Any]]]: # First we calculate the max stream id that will give us less than # N results. 
# We arbitrarily limit to 100 stream_id entries to ensure we don't @@ -106,7 +116,9 @@ class StateDeltasStore(SQLBaseStore): "get_current_state_deltas", get_current_state_deltas_txn ) - def _get_max_stream_id_in_current_state_deltas_txn(self, txn): + def _get_max_stream_id_in_current_state_deltas_txn( + self, txn: LoggingTransaction + ) -> int: return self.db_pool.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", @@ -114,7 +126,7 @@ class StateDeltasStore(SQLBaseStore): retcol="COALESCE(MAX(stream_id), -1)", ) - async def get_max_stream_id_in_current_state_deltas(self): + async def get_max_stream_id_in_current_state_deltas(self) -> int: return await self.db_pool.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 42dc807d17..57aab55259 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): oldest `limit` events. Returns: - The list of events (in ascending order) and the token from the start + The list of events (in ascending stream order) and the token from the start of the chunk of events returned. """ if from_key == to_key: @@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [], from_key - def f(txn): + def f(txn: LoggingTransaction) -> List[_EventDictReturn]: # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: + """Fetch membership events for a given user. + + All such events whose stream ordering `s` lies in the range + `from_key < s <= to_key` are returned. Events are ordered by ascending stream + order. + """ + # Start by ruling out cases where a DB query is not necessary. if from_key == to_key: return [] @@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if not has_changed: return [] - def f(txn): + def f(txn: LoggingTransaction) -> List[_EventDictReturn]: # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream @@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): Returns: A list of events and a token pointing to the start of the returned - events. The events returned are in ascending order. + events. The events returned are in ascending topological order. """ rows, token = await self.get_recent_event_ids_for_room( diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index f93ff0a545..8f510de53d 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -1,5 +1,6 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +15,10 @@ # limitations under the License. 
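The `StateDeltasStore` hunk above annotates `_curr_state_delta_stream_cache: StreamChangeCache` without assigning it: the mixin declares an attribute it expects the concrete store class to provide, so mypy can check its uses (the accompanying TODO asks whether static analysis could enforce the contract). A hedged, self-contained sketch of the pattern, with invented names:

    from synapse.util.caches.stream_change_cache import StreamChangeCache

    class StateDeltaReaderMixin:
        # Declared but never assigned here; the class this is mixed into
        # must provide the attribute at construction time.
        _curr_state_delta_stream_cache: StreamChangeCache

        def anything_changed_since(self, stream_id: int) -> bool:
            # Uses the declared attribute; mypy can now type-check this call.
            return self._curr_state_delta_stream_cache.has_any_entity_changed(stream_id)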
import logging -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, cast from synapse.storage._base import db_to_json +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.types import JsonDict from synapse.util import json_encoder @@ -50,7 +52,7 @@ class TagsWorkerStore(AccountDataWorkerStore): async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]: """Get updates for tags replication stream. Args: @@ -75,7 +77,9 @@ class TagsWorkerStore(AccountDataWorkerStore): if last_id == current_id: return [], current_id, False - def get_all_updated_tags_txn(txn): + def get_all_updated_tags_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str]]: sql = ( "SELECT stream_id, user_id, room_id" " FROM room_tags_revisions as r" @@ -83,13 +87,16 @@ class TagsWorkerStore(AccountDataWorkerStore): " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + # mypy doesn't understand what the query is selecting. + return cast(List[Tuple[int, str, str]], txn.fetchall()) tag_ids = await self.db_pool.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) - def get_tag_content(txn, tag_ids): + def get_tag_content( + txn: LoggingTransaction, tag_ids + ) -> List[Tuple[int, Tuple[str, str, str]]]: sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" results = [] for stream_id, user_id, room_id in tag_ids: @@ -127,15 +134,15 @@ class TagsWorkerStore(AccountDataWorkerStore): given version Args: - user_id(str): The user to get the tags for. - stream_id(int): The earliest update to get for the user. + user_id: The user to get the tags for. + stream_id: The earliest update to get for the user. Returns: A mapping from room_id strings to lists of tag strings for all the rooms that changed since the stream_id token. """ - def get_updated_tags_txn(txn): + def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]: sql = ( "SELECT room_id from room_tags_revisions" " WHERE user_id = ? AND stream_id > ?" @@ -200,7 +207,7 @@ class TagsWorkerStore(AccountDataWorkerStore): content_json = json_encoder.encode(content) - def add_tag_txn(txn, next_id): + def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None: self.db_pool.simple_upsert_txn( txn, table="room_tags", @@ -224,7 +231,7 @@ class TagsWorkerStore(AccountDataWorkerStore): """ assert self._can_write_to_account_data - def remove_tag_txn(txn, next_id): + def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: sql = ( "DELETE FROM room_tags " " WHERE user_id = ? AND room_id = ? AND tag = ?" 
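The tags.py hunk above wraps `txn.fetchall()` in `cast(...)` because the DB-API cursor is untyped and mypy cannot infer column types from SQL text. A minimal sketch of the idiom, assuming (as the comment in the diff does) that the author knows the row shape the query selects:

    from typing import List, Tuple, cast

    from synapse.storage.database import LoggingTransaction

    def fetch_tag_revisions(
        txn: LoggingTransaction, last_id: int, limit: int
    ) -> List[Tuple[int, str, str]]:
        txn.execute(
            "SELECT stream_id, user_id, room_id FROM room_tags_revisions"
            " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?",
            (last_id, limit),
        )
        # The cast asserts the (stream_id, user_id, room_id) row shape that
        # the SELECT guarantees but mypy cannot see.
        return cast(List[Tuple[int, str, str]], txn.fetchall())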
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index d7dc1f73ac..1622822552 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,6 +14,7 @@ import logging from collections import namedtuple +from enum import Enum from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple import attr @@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple( ) +class DestinationSortOrder(Enum): + """Enum to define the sorting method used when returning destinations.""" + + DESTINATION = "destination" + RETRY_LAST_TS = "retry_last_ts" + RETRY_INTERVAL = "retry_interval" + FAILURE_TS = "failure_ts" + LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class DestinationRetryTimings: """The current destination retry timing info for a remote server.""" @@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): destinations = [row[0] for row in txn] return destinations + + async def get_destinations_paginate( + self, + start: int, + limit: int, + destination: Optional[str] = None, + order_by: str = DestinationSortOrder.DESTINATION.value, + direction: str = "f", + ) -> Tuple[List[JsonDict], int]: + """Function to retrieve a paginated list of destinations. + This will return a json list of destinations and the + total number of destinations matching the filter criteria. + + Args: + start: start number to begin the query from + limit: number of rows to retrieve + destination: search string in destination + order_by: the sort order of the returned list + direction: sort ascending or descending + Returns: + A tuple of a list of mappings from destination to information + and a count of total destinations. + """ + + def get_destinations_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: + order_by_column = DestinationSortOrder(order_by).value + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + args = [] + where_statement = "" + if destination: + args.extend(["%" + destination.lower() + "%"]) + where_statement = "WHERE LOWER(destination) LIKE ?" + + sql_base = f"FROM destinations {where_statement} " + sql = f"SELECT COUNT(*) as total_destinations {sql_base}" + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = f""" + SELECT destination, retry_last_ts, retry_interval, failure_ts, + last_successful_stream_ordering + {sql_base} + ORDER BY {order_by_column} {order}, destination ASC + LIMIT ? OFFSET ?
+ """ + txn.execute(sql, args + [limit, start]) + destinations = self.db_pool.cursor_to_dict(txn) + return destinations, count + + return await self.db_pool.runInteraction( + "get_destinations_paginate_txn", get_destinations_paginate_txn + ) diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index 1ecdd40c38..f79006533f 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -14,11 +14,12 @@ from typing import Dict, Iterable -from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.util.caches.descriptors import cached, cachedList -class UserErasureWorkerStore(SQLBaseStore): +class UserErasureWorkerStore(CacheInvalidationWorkerStore): @cached() async def is_user_erased(self, user_id: str) -> bool: """ @@ -69,7 +70,7 @@ class UserErasureStore(UserErasureWorkerStore): user_id: full user_id to be erased """ - def f(txn): + def f(txn: LoggingTransaction) -> None: # first check if they are already in the list txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) if txn.fetchone(): @@ -89,7 +90,7 @@ class UserErasureStore(UserErasureWorkerStore): user_id: full user_id to be un-erased """ - def f(txn): + def f(txn: LoggingTransaction) -> None: # first check if they are already in the list txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) if not txn.fetchone(): diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 402f134d89..428d66a617 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -583,7 +583,8 @@ class EventsPersistenceStorage: current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, - backfilled=backfilled, + use_negative_stream_ordering=backfilled, + inhibit_local_membership_updates=backfilled, ) await self._handle_potentially_left_users(potentially_left_users) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 8b9c6adae2..e45adfcb55 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -131,24 +131,16 @@ def prepare_database( "config==None in prepare_database, but database is not empty" ) - # if it's a worker app, refuse to upgrade the database, to avoid multiple - # workers doing it at once. - if config.worker.worker_app is None: - _upgrade_existing_database( - cur, - version_info, - database_engine, - config, - databases=databases, - ) - elif version_info.current_version < SCHEMA_VERSION: - # If the DB is on an older version than we expect then we refuse - # to start the worker (as the main process needs to run first to - # update the schema). - raise UpgradeDatabaseException( - OUTDATED_SCHEMA_ON_WORKER_ERROR - % (SCHEMA_VERSION, version_info.current_version) - ) + # This should be run on all processes, master or worker. The master will + # apply the deltas, while workers will check if any outstanding deltas + # exist and raise an PrepareDatabaseException if they do. 
+ _upgrade_existing_database( + cur, + version_info, + database_engine, + config, + databases=databases, + ) else: logger.info("%r: Initialising new database", databases) @@ -358,6 +350,18 @@ def _upgrade_existing_database( is_worker = config and config.worker.worker_app is not None + # If the schema version needs to be updated, and we are on a worker, we immediately + # know to bail out as workers cannot update the database schema. Only one process + # must update the database at a time, so we delegate this task to the master. + if is_worker and current_schema_state.current_version < SCHEMA_VERSION: + # If the DB is on an older version than we expect then we refuse + # to start the worker (as the main process needs to run first to + # update the schema). + raise UpgradeDatabaseException( + OUTDATED_SCHEMA_ON_WORKER_ERROR + % (SCHEMA_VERSION, current_schema_state.current_version) + ) + if ( current_schema_state.compat_version is not None and current_schema_state.compat_version > SCHEMA_VERSION diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index a1d2332326..50d08094d5 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 65 # remember to update the list below when updating +SCHEMA_VERSION = 66 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -45,10 +45,17 @@ Changes in SCHEMA_VERSION = 64: Changes in SCHEMA_VERSION = 65: - MSC2716: Remove unique event_id constraint from insertion_event_edges because an insertion event can have multiple edges. + - Remove unused tables `user_stats_historical` and `room_stats_historical`. + +Changes in SCHEMA_VERSION = 66: + - Queries on state_key columns are now disambiguated (ie, the codebase can handle + the `events` table having a `state_key` column). """ -SCHEMA_COMPAT_VERSION = 60 # 60: "outlier" not in internal_metadata. +SCHEMA_COMPAT_VERSION = ( + 61 # 61: Remove unused tables `user_stats_historical` and `room_stats_historical` +) """Limit on how far the synapse codebase can be rolled back without breaking db compat This value is stored in the database, and checked on startup. If the value in the diff --git a/synapse/storage/schema/main/delta/65/05_remove_room_stats_historical_and_user_stats_historical.sql b/synapse/storage/schema/main/delta/65/05_remove_room_stats_historical_and_user_stats_historical.sql new file mode 100644 index 0000000000..a145180e7a --- /dev/null +++ b/synapse/storage/schema/main/delta/65/05_remove_room_stats_historical_and_user_stats_historical.sql @@ -0,0 +1,19 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + + -- Remove unused tables room_stats_historical and user_stats_historical + -- which have not been read or written since schema version 61. + DROP TABLE IF EXISTS room_stats_historical; + DROP TABLE IF EXISTS user_stats_historical; \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql new file mode 100644 index 0000000000..82f6408b36 --- /dev/null +++ b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql @@ -0,0 +1,34 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Remove messages from the device_inbox table which were orphaned +-- when a device was deleted using Synapse earlier than 1.47.0. +-- This runs as a background task, but may take a bit to finish. + +-- Remove any existing instances of this job running. It's OK to stop and restart this job, +-- as it's just deleting entries from a table - no progress will be lost. +-- +-- This is necessary due to a similar migration running the job accidentally +-- being included in schema version 64 during v1.47.0rc1,rc2. If a +-- homeserver had updated from Synapse <=v1.45.0 (schema version <=64), +-- then they would have started running this background update already. +-- If that update was still running, then simply inserting it again would +-- cause an SQL failure. So we effectively do an "upsert" here instead. + +DELETE FROM background_updates WHERE update_name = 'remove_deleted_devices_from_device_inbox'; + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6506, 'remove_deleted_devices_from_device_inbox', '{}'); diff --git a/synapse/storage/schema/main/delta/65/02_thread_relations.sql b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql index d60517f7b4..267b2cb539 100644 --- a/synapse/storage/schema/main/delta/65/02_thread_relations.sql +++ b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql @@ -15,4 +15,4 @@ -- Check old events for thread relations. INSERT INTO background_updates (ordering, update_name, progress_json) VALUES - (6502, 'event_thread_relation', '{}'); + (6507, 'event_arbitrary_relations', '{}'); diff --git a/synapse/storage/schema/main/delta/64/02remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql index fca7290741..d79455c2ce 100644 --- a/synapse/storage/schema/main/delta/64/02remove_deleted_devices_from_device_inbox.sql +++ b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql @@ -13,10 +13,6 @@ * limitations under the License. */ - --- Remove messages from the device_inbox table which were orphaned --- when a device was deleted using Synapse earlier than 1.47.0. --- This runs as background task, but may take a bit to finish. - +-- Background update to clear the inboxes of hidden and deleted devices.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES - (6402, 'remove_deleted_devices_from_device_inbox', '{}'); + (6508, 'remove_dead_devices_from_device_inbox', '{}'); diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql new file mode 100644 index 0000000000..bdc491c817 --- /dev/null +++ b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql @@ -0,0 +1,28 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE refresh_tokens + -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens. + -- They may not be used after they have expired. + -- If null, then the refresh token's lifetime is unlimited. + ADD COLUMN expiry_ts BIGINT DEFAULT NULL; + +ALTER TABLE refresh_tokens + -- We also add an ultimate session expiry time (in milliseconds since the Epoch). + -- No matter how much the access and refresh tokens are refreshed, they cannot + -- be extended past this time. + -- If null, then the session length is unlimited. + ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL; diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql new file mode 100644 index 0000000000..a65bfb520d --- /dev/null +++ b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql @@ -0,0 +1,27 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track the auth provider used by each login as well as the session ID +CREATE TABLE device_auth_providers ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + auth_provider_id TEXT NOT NULL, + auth_provider_session_id TEXT NOT NULL +); + +CREATE INDEX device_auth_providers_devices + ON device_auth_providers (user_id, device_id); +CREATE INDEX device_auth_providers_sessions + ON device_auth_providers (auth_provider_id, auth_provider_session_id); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 670811611f..b8112e1c05 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import heapq import logging import threading @@ -72,8 +74,6 @@ class IdGenerator: def _load_current_id( db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1 ) -> int: - # debug logging for https://github.com/matrix-org/synapse/issues/7968 - logger.info("initialising stream generator for %s(%s)", table, column) cur = db_conn.cursor(txn_name="_load_current_id") if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) @@ -84,16 +84,82 @@ def _load_current_id( (val,) = result cur.close() current_id = int(val) if val else step - return (max if step > 0 else min)(current_id, step) + res = (max if step > 0 else min)(current_id, step) + logger.info("Initialising stream generator for %s(%s): %i", table, column, res) + return res -class StreamIdGenerator: - """Used to generate new stream ids when persisting events while keeping - track of which transactions have been completed. +class AbstractStreamIdTracker(metaclass=abc.ABCMeta): + """Tracks the "current" stream ID of a stream that may have multiple writers. - This allows us to get the "current" stream id, i.e. the stream id such that - all ids less than or equal to it have completed. This handles the fact that - persistence of events can complete out of order. + Stream IDs are monotonically increasing or decreasing integers representing write + transactions. The "current" stream ID is the stream ID such that all transactions + with equal or smaller stream IDs have completed. Since transactions may complete out + of order, this is not the same as the stream ID of the last completed transaction. + + Completed transactions include both committed transactions and transactions that + have been rolled back. + """ + + @abc.abstractmethod + def advance(self, instance_name: str, new_id: int) -> None: + """Advance the position of the named writer to the given ID, if greater + than existing entry. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + + Returns: + The maximum stream id. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to `get_current_token`. + """ + raise NotImplementedError() + + +class AbstractStreamIdGenerator(AbstractStreamIdTracker): + """Generates stream IDs for a stream that may have multiple writers. + + Each stream ID represents a write transaction, whose completion is tracked + so that the "current" stream ID of the stream can be determined. + + See `AbstractStreamIdTracker` for more details. + """ + + @abc.abstractmethod + def get_next(self) -> AsyncContextManager[int]: + """ + Usage: + async with stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: + """ + Usage: + async with stream_id_gen.get_next(n) as stream_ids: + # ... persist events ... 
+ """ + raise NotImplementedError() + + +class StreamIdGenerator(AbstractStreamIdGenerator): + """Generates and tracks stream IDs for a stream with a single writer. + + This class must only be used when the current Synapse process is the sole + writer for a stream. Args: db_conn(connection): A database connection to use to fetch the @@ -137,12 +203,12 @@ class StreamIdGenerator: # The key and values are the same, but we never look at the values. self._unfinished_ids: OrderedDict[int, int] = OrderedDict() + def advance(self, instance_name: str, new_id: int) -> None: + # `StreamIdGenerator` should only be used when there is a single writer, + # so replication should never happen. + raise Exception("Replication is not supported by StreamIdGenerator") + def get_next(self) -> AsyncContextManager[int]: - """ - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... - """ with self._lock: self._current += self._step next_id = self._current @@ -160,11 +226,6 @@ class StreamIdGenerator: return _AsyncCtxManagerWrapper(manager()) def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: - """ - Usage: - async with stream_id_gen.get_next(n) as stream_ids: - # ... persist events ... - """ with self._lock: next_ids = range( self._current + self._step, @@ -188,12 +249,6 @@ class StreamIdGenerator: return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. - - Returns: - The maximum stream id. - """ with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step @@ -201,16 +256,11 @@ class StreamIdGenerator: return self._current def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. - - For streams with single writers this is equivalent to - `get_current_token`. - """ return self.get_current_token() -class MultiWriterIdGenerator: - """An ID generator that tracks a stream that can have multiple writers. +class MultiWriterIdGenerator(AbstractStreamIdGenerator): + """Generates and tracks stream IDs for a stream with multiple writers. Uses a Postgres sequence to coordinate ID assignment, but positions of other writers will only get updated when `advance` is called (by replication). @@ -455,12 +505,6 @@ class MultiWriterIdGenerator: return stream_ids def get_next(self) -> AsyncContextManager[int]: - """ - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... - """ - # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. if self._writers and self._instance_name not in self._writers: @@ -472,12 +516,6 @@ class MultiWriterIdGenerator: return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: - """ - Usage: - async with stream_id_gen.get_next_mult(5) as stream_ids: - # ... persist events ... - """ - # If we have a list of instances that are allowed to write to this # stream, make sure we're in it. if self._writers and self._instance_name not in self._writers: @@ -577,15 +615,9 @@ class MultiWriterIdGenerator: self._add_persisted_position(next_id) def get_current_token(self) -> int: - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. 
- """ - return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer.""" - # If we don't have an entry for the given instance name, we assume it's a # new writer. # @@ -611,10 +643,6 @@ class MultiWriterIdGenerator: } def advance(self, instance_name: str, new_id: int) -> None: - """Advance the position of the named writer to the given ID, if greater - than existing entry. - """ - new_id *= self._return_factor with self._lock: diff --git a/synapse/types.py b/synapse/types.py index 364ecf7d45..fb72f19343 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -19,6 +19,7 @@ from collections import namedtuple from typing import ( TYPE_CHECKING, Any, + ClassVar, Dict, Mapping, MutableMapping, @@ -38,6 +39,7 @@ from zope.interface import Interface from twisted.internet.interfaces import ( IReactorCore, IReactorPluggableNameResolver, + IReactorSSL, IReactorTCP, IReactorThreads, IReactorTime, @@ -66,6 +68,7 @@ JsonDict = Dict[str, Any] # for mypy-zope to realize it is an interface. class ISynapseReactor( IReactorTCP, + IReactorSSL, IReactorPluggableNameResolver, IReactorTime, IReactorCore, @@ -217,7 +220,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): 'domain' : The domain part of the name """ - SIGIL: str = abc.abstractproperty() # type: ignore + SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore localpart = attr.ib(type=str) domain = attr.ib(type=str) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index abf53d149d..95f23e27b6 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -16,7 +16,7 @@ import json import logging import re import typing -from typing import Any, Callable, Dict, Generator, Pattern +from typing import Any, Callable, Dict, Generator, Optional, Pattern import attr from frozendict import frozendict @@ -110,7 +110,9 @@ class Clock: """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) - def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall: + def looping_call( + self, f: Callable, msec: float, *args: Any, **kwargs: Any + ) -> LoopingCall: """Call a function repeatedly. Waits `msec` initially before calling `f` for the first time. @@ -130,20 +132,22 @@ class Clock: d.addErrback(log_failure, "Looping call died", consumeErrors=False) return call - def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall: + def call_later( + self, delay: float, callback: Callable, *args: Any, **kwargs: Any + ) -> IDelayedCall: """Call something later Note that the function will be called with no logcontext, so if it is anything other than trivial, you probably want to wrap it in run_as_background_process. Args: - delay(float): How long to wait in seconds. - callback(function): Function to call + delay: How long to wait in seconds. + callback: Function to call *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. """ - def wrapped_callback(*args, **kwargs): + def wrapped_callback(*args: Any, **kwargs: Any) -> None: with context.PreserveLoggingContext(): callback(*args, **kwargs) @@ -158,25 +162,29 @@ class Clock: raise -def log_failure(failure, msg, consumeErrors=True): +def log_failure( + failure: Failure, msg: str, consumeErrors: bool = True +) -> Optional[Failure]: """Creates a function suitable for passing to `Deferred.addErrback` that logs any failures that occur. 
Args: - msg (str): Message to log - consumeErrors (bool): If true consumes the failure, otherwise passes - on down the callback chain + failure: The Failure to log + msg: Message to log + consumeErrors: If true consumes the failure, otherwise passes on down + the callback chain Returns: - func(Failure) + The Failure if consumeErrors is false. None, otherwise. """ logger.error( - msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) + msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) # type: ignore[arg-type] ) if not consumeErrors: return failure + return None def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern: diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 96efc5f3e3..20ce294209 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -27,20 +27,20 @@ from typing import ( Generic, Hashable, Iterable, + Iterator, Optional, Set, TypeVar, Union, + cast, ) import attr from typing_extensions import ContextManager from twisted.internet import defer -from twisted.internet.base import ReactorBase from twisted.internet.defer import CancelledError from twisted.internet.interfaces import IReactorTime -from twisted.python import failure from twisted.python.failure import Failure from synapse.logging.context import ( @@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]): object.__setattr__(self, "_result", None) object.__setattr__(self, "_observers", []) - def callback(r): + def callback(r: _T) -> _T: object.__setattr__(self, "_result", (True, r)) # once we have set _result, no more entries will be added to _observers, @@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]): ) return r - def errback(f): + def errback(f: Failure) -> Optional[Failure]: object.__setattr__(self, "_result", (False, f)) # once we have set _result, no more entries will be added to _observers, @@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]): for observer in observers: # This is a little bit of magic to correctly propagate stack # traces when we `await` on one of the observer deferreds. - f.value.__failure__ = f + f.value.__failure__ = f # type: ignore[union-attr] try: observer.errback(f) except Exception as e: @@ -271,8 +271,7 @@ class Linearizer: if not clock: from twisted.internet import reactor - assert isinstance(reactor, ReactorBase) - clock = Clock(reactor) + clock = Clock(cast(IReactorTime, reactor)) self._clock = clock self.max_count = max_count @@ -315,7 +314,7 @@ class Linearizer: # will release the lock. @contextmanager - def _ctx_manager(_): + def _ctx_manager(_: None) -> Iterator[None]: try: yield finally: @@ -356,7 +355,7 @@ class Linearizer: new_defer = make_deferred_yieldable(defer.Deferred()) entry.deferreds[new_defer] = 1 - def cb(_r): + def cb(_r: None) -> "defer.Deferred[None]": logger.debug("Acquired linearizer lock %r for key %r", self.name, key) entry.count += 1 @@ -372,7 +371,7 @@ class Linearizer: # code must be synchronous, so this is the only sensible place.) 
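# The Linearizer hunk above trades an isinstance assert for a cast: there is
# no runtime check any more, mypy is simply told the global reactor can keep
# time. The idiom in isolation:
from typing import cast

from twisted.internet import reactor
from twisted.internet.interfaces import IReactorTime

from synapse.util import Clock

clock = Clock(cast(IReactorTime, reactor))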
return self._clock.sleep(0) - def eb(e): + def eb(e: Failure) -> Failure: logger.info("defer %r got err %r", new_defer, e) if isinstance(e, CancelledError): logger.debug( @@ -436,7 +435,7 @@ class ReadWriteLock: await make_deferred_yieldable(curr_writer) @contextmanager - def _ctx_manager(): + def _ctx_manager() -> Iterator[None]: try: yield finally: @@ -465,7 +464,7 @@ class ReadWriteLock: await make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager - def _ctx_manager(): + def _ctx_manager() -> Iterator[None]: try: yield finally: @@ -525,7 +524,7 @@ def timeout_deferred( delayed_call = reactor.callLater(timeout, time_it_out) - def convert_cancelled(value: failure.Failure): + def convert_cancelled(value: Failure) -> Failure: # if the original deferred was cancelled, and our timeout has fired, then # the reason it was cancelled was due to our timeout. Turn the CancelledError # into a TimeoutError. @@ -535,7 +534,7 @@ def timeout_deferred( deferred.addErrback(convert_cancelled) - def cancel_timeout(result): + def cancel_timeout(result: _T) -> _T: # stop the pending call to cancel the deferred if it's been fired if delayed_call.active(): delayed_call.cancel() @@ -543,11 +542,11 @@ def timeout_deferred( deferred.addBoth(cancel_timeout) - def success_cb(val): + def success_cb(val: _T) -> None: if not new_d.called: new_d.callback(val) - def failure_cb(val): + def failure_cb(val: Failure) -> None: if not new_d.called: new_d.errback(val) @@ -558,13 +557,13 @@ def timeout_deferred( # This class can't be generic because it uses slots with attrs. # See: https://github.com/python-attrs/attrs/issues/313 -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class DoneAwaitable: # should be: Generic[R] """Simple awaitable that returns the provided value.""" - value = attr.ib(type=Any) # should be: R + value: Any # should be: R - def __await__(self): + def __await__(self) -> Any: return self def __iter__(self) -> "DoneAwaitable": diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index df4d61e4b6..15debd6c46 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -17,7 +17,7 @@ import logging import typing from enum import Enum, auto from sys import intern -from typing import Callable, Dict, Optional, Sized +from typing import Any, Callable, Dict, List, Optional, Sized import attr from prometheus_client.core import Gauge @@ -58,20 +58,20 @@ class EvictionReason(Enum): time = auto() -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class CacheMetric: - _cache = attr.ib() - _cache_type = attr.ib(type=str) - _cache_name = attr.ib(type=str) - _collect_callback = attr.ib(type=Optional[Callable]) + _cache: Sized + _cache_type: str + _cache_name: str + _collect_callback: Optional[Callable] - hits = attr.ib(default=0) - misses = attr.ib(default=0) + hits: int = 0 + misses: int = 0 eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib( factory=collections.Counter ) - memory_usage = attr.ib(default=None) + memory_usage: Optional[int] = None def inc_hits(self) -> None: self.hits += 1 @@ -89,13 +89,14 @@ class CacheMetric: self.memory_usage += memory def dec_memory_usage(self, memory: int) -> None: + assert self.memory_usage is not None self.memory_usage -= memory def clear_memory_usage(self) -> None: if self.memory_usage is not None: self.memory_usage = 0 - def describe(self): + def describe(self) -> List[str]: return [] def collect(self) -> None: @@ -118,8 +119,9 @@ class CacheMetric: 
self.eviction_size_by_reason[reason] ) cache_total.labels(self._cache_name).set(self.hits + self.misses) - if getattr(self._cache, "max_size", None): - cache_max_size.labels(self._cache_name).set(self._cache.max_size) + max_size = getattr(self._cache, "max_size", None) + if max_size: + cache_max_size.labels(self._cache_name).set(max_size) if TRACK_MEMORY_USAGE: # self.memory_usage can be None if nothing has been inserted @@ -193,7 +195,7 @@ KNOWN_KEYS = { } -def intern_string(string): +def intern_string(string: Optional[str]) -> Optional[str]: """Takes a (potentially) unicode string and interns it if it's ascii""" if string is None: return None @@ -204,7 +206,7 @@ def intern_string(string): return string -def intern_dict(dictionary): +def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]: """Takes a dictionary and interns well known keys and their values""" return { KNOWN_KEYS.get(key, key): _intern_known_values(key, value) @@ -212,7 +214,7 @@ def intern_dict(dictionary): } -def _intern_known_values(key, value): +def _intern_known_values(key: str, value: Any) -> Any: intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key") if key in intern_keys: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index da502aec11..377c9a282a 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -22,6 +22,7 @@ from typing import ( Iterable, MutableMapping, Optional, + Sized, TypeVar, Union, cast, @@ -104,7 +105,13 @@ class DeferredCache(Generic[KT, VT]): max_size=max_entries, cache_name=name, cache_type=cache_type, - size_callback=(lambda d: len(d) or 1) if iterable else None, + size_callback=( + (lambda d: len(cast(Sized, d)) or 1) + # Argument 1 to "len" has incompatible type "VT"; expected "Sized" + # We trust that `VT` is `Sized` when `iterable` is `True` + if iterable + else None + ), metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, prune_unread_entries=prune_unread_entries, @@ -289,7 +296,7 @@ class DeferredCache(Generic[KT, VT]): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) - def invalidate(self, key) -> None: + def invalidate(self, key: KT) -> None: """Delete a key, or tree of entries If the cache is backed by a regular dict, then "key" must be of diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b9dcca17f1..375cd443f1 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -19,12 +19,15 @@ import logging from typing import ( Any, Callable, + Dict, Generic, + Hashable, Iterable, Mapping, Optional, Sequence, Tuple, + Type, TypeVar, Union, cast, @@ -32,6 +35,7 @@ from typing import ( from weakref import WeakValueDictionary from twisted.internet import defer +from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError @@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]): class _CacheDescriptorBase: - def __init__(self, orig: Callable[..., Any], num_args, cache_context=False): + def __init__( + self, + orig: Callable[..., Any], + num_args: Optional[int], + cache_context: bool = False, + ): self.orig = orig arg_spec = inspect.getfullargspec(orig) @@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase): def __init__( self, - orig, + orig: Callable[..., Any], max_entries: int = 1000, cache_context: bool = False, 
): super().__init__(orig, num_args=None, cache_context=cache_context) self.max_entries = max_entries - def __get__(self, obj, owner): + def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: LruCache[CacheKey, Any] = LruCache( cache_name=self.orig.__name__, max_size=self.max_entries, @@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): sentinel = LruCacheDescriptor._Sentinel.sentinel @functools.wraps(self.orig) - def _wrapped(*args, **kwargs): + def _wrapped(*args: Any, **kwargs: Any) -> Any: invalidate_callback = kwargs.pop("on_invalidate", None) callbacks = (invalidate_callback,) if invalidate_callback else () @@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): return r1 + r2 Args: - num_args (int): number of positional arguments (excluding ``self`` and + num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. """ def __init__( self, - orig, - max_entries=1000, - num_args=None, - tree=False, - cache_context=False, - iterable=False, + orig: Callable[..., Any], + max_entries: int = 1000, + num_args: Optional[int] = None, + tree: bool = False, + cache_context: bool = False, + iterable: bool = False, prune_unread_entries: bool = True, ): super().__init__(orig, num_args=num_args, cache_context=cache_context) @@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): self.iterable = iterable self.prune_unread_entries = prune_unread_entries - def __get__(self, obj, owner): + def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: DeferredCache[CacheKey, Any] = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, @@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): get_cache_key = self.cache_key_builder @functools.wraps(self.orig) - def _wrapped(*args, **kwargs): + def _wrapped(*args: Any, **kwargs: Any) -> Any: # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): of results. """ - def __init__(self, orig, cached_method_name, list_name, num_args=None): + def __init__( + self, + orig: Callable[..., Any], + cached_method_name: str, + list_name: str, + num_args: Optional[int] = None, + ): """ Args: - orig (function) - cached_method_name (str): The name of the cached method. - list_name (str): Name of the argument which is the bulk lookup list - num_args (int): number of positional arguments (excluding ``self``, + orig + cached_method_name: The name of the cached method. + list_name: Name of the argument which is the bulk lookup list + num_args: number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all named args of the function. 
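# These descriptors back the @cached and @cachedList decorators; a minimal
# sketch of the call surface being annotated here (ExampleStore and its
# queries are hypothetical):
class ExampleStore:
    @cached(num_args=1)
    async def get_user(self, user_id: str) -> JsonDict:
        ...

    @cachedList(cached_method_name="get_user", list_name="user_ids")
    async def get_users(self, user_ids: Collection[str]) -> Dict[str, JsonDict]:
        ...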
""" @@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): % (self.list_name, cached_method_name) ) - def __get__(self, obj, objtype=None): + def __get__( + self, obj: Optional[Any], objtype: Optional[Type] = None + ) -> Callable[..., Any]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> Any: # If we're passed a cache_context then we'll want to call its # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): results = {} - def update_results_dict(res, arg): + def update_results_dict(res: Any, arg: Hashable) -> None: results[arg] = res # list of deferreds to wait for @@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): # otherwise a tuple is used. if num_args == 1: - def arg_to_cache_key(arg): + def arg_to_cache_key(arg: Hashable) -> Hashable: return arg else: keylist = list(keyargs) - def arg_to_cache_key(arg): + def arg_to_cache_key(arg: Hashable) -> Hashable: keylist[self.list_pos] = arg return tuple(keylist) @@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): key = arg_to_cache_key(arg) cache.set(key, deferred, callback=invalidate_callback) - def complete_all(res): + def complete_all(res: Dict[Hashable, Any]) -> None: # the wrapped function has completed. It returns a # a dict. We can now resolve the observable deferreds in # the cache and update our own result map. @@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): deferreds_map[e].callback(val) results[e] = val - def errback(f): + def errback(f: Failure) -> Failure: # the wrapped function has failed. Invalidate any cache # entries we're supposed to be populating, and fail # their deferreds. diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index c3f72aa06d..67ee4c693b 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload import attr from typing_extensions import Literal +from twisted.internet import defer + from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util import Clock @@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]): # Don't bother starting the loop if things never expire return - def f(): + def f() -> "defer.Deferred[None]": return run_as_background_process( "prune_cache_%s" % self._cache_name, self._prune_cache ) @@ -157,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]): self[key] = value return value - def _prune_cache(self) -> None: + async def _prune_cache(self) -> None: if not self._expiry_ms: # zero expiry time means don't expire. This should never get called # since we have this check in start too. 
@@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]): return False -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _CacheEntry: - time = attr.ib(type=int) - value = attr.ib() + time: int + value: Any diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index a0a7a9de32..eb96f7e665 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -15,14 +15,15 @@ import logging import threading import weakref +from enum import Enum from functools import wraps from typing import ( TYPE_CHECKING, Any, Callable, Collection, + Dict, Generic, - Iterable, List, Optional, Type, @@ -190,7 +191,7 @@ class _Node(Generic[KT, VT]): root: "ListNode[_Node]", key: KT, value: VT, - cache: "weakref.ReferenceType[LruCache]", + cache: "weakref.ReferenceType[LruCache[KT, VT]]", clock: Clock, callbacks: Collection[Callable[[], None]] = (), prune_unread_entries: bool = True, @@ -270,7 +271,10 @@ class _Node(Generic[KT, VT]): removed from all lists. """ cache = self._cache() - if not cache or not cache.pop(self.key, None): + if ( + cache is None + or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel + ): # `cache.pop` should call `drop_from_lists()`, unless this Node had # already been removed from the cache. self.drop_from_lists() @@ -290,6 +294,12 @@ class _Node(Generic[KT, VT]): self._global_list_node.update_last_access(clock) +class _Sentinel(Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + class LruCache(Generic[KT, VT]): """ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. @@ -302,7 +312,7 @@ class LruCache(Generic[KT, VT]): max_size: int, cache_name: Optional[str] = None, cache_type: Type[Union[dict, TreeCache]] = dict, - size_callback: Optional[Callable] = None, + size_callback: Optional[Callable[[VT], int]] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, clock: Optional[Clock] = None, @@ -339,7 +349,7 @@ class LruCache(Generic[KT, VT]): else: real_clock = clock - cache = cache_type() + cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type() self.cache = cache # Used for introspection. self.apply_cache_factor_from_config = apply_cache_factor_from_config @@ -374,7 +384,7 @@ class LruCache(Generic[KT, VT]): # creating more each time we create a `_Node`. 
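# The _Sentinel enum defined above exists so mypy can type "missing" lookups
# precisely; the pattern, in isolation (handle_miss/handle_hit are
# hypothetical):
value = cache.pop(key, _Sentinel.sentinel)
if value is _Sentinel.sentinel:
    handle_miss()
else:
    handle_hit(value)  # mypy has narrowed `value` to VT here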
weak_ref_to_self = weakref.ref(self) - list_root = ListNode[_Node].create_root_node() + list_root = ListNode[_Node[KT, VT]].create_root_node() lock = threading.Lock() @@ -422,7 +432,7 @@ class LruCache(Generic[KT, VT]): def add_node( key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () ) -> None: - node = _Node( + node: _Node[KT, VT] = _Node( list_root, key, value, @@ -439,10 +449,10 @@ class LruCache(Generic[KT, VT]): if caches.TRACK_MEMORY_USAGE and metrics: metrics.inc_memory_usage(node.memory) - def move_node_to_front(node: _Node) -> None: + def move_node_to_front(node: _Node[KT, VT]) -> None: node.move_to_front(real_clock, list_root) - def delete_node(node: _Node) -> int: + def delete_node(node: _Node[KT, VT]) -> int: node.drop_from_lists() deleted_len = 1 @@ -496,7 +506,7 @@ class LruCache(Generic[KT, VT]): @synchronized def cache_set( - key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = () + key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () ) -> None: node = cache.get(key, None) if node is not None: @@ -590,8 +600,6 @@ class LruCache(Generic[KT, VT]): def cache_contains(key: KT) -> bool: return key in cache - self.sentinel = object() - # make sure that we clear out any excess entries after we get resized. self._on_resize = evict @@ -608,18 +616,18 @@ class LruCache(Generic[KT, VT]): self.clear = cache_clear def __getitem__(self, key: KT) -> VT: - result = self.get(key, self.sentinel) - if result is self.sentinel: + result = self.get(key, _Sentinel.sentinel) + if result is _Sentinel.sentinel: raise KeyError() else: - return cast(VT, result) + return result def __setitem__(self, key: KT, value: VT) -> None: self.set(key, value) def __delitem__(self, key: KT, value: VT) -> None: - result = self.pop(key, self.sentinel) - if result is self.sentinel: + result = self.pop(key, _Sentinel.sentinel) + if result is _Sentinel.sentinel: raise KeyError() def __len__(self) -> int: diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 31097d6439..91837655f8 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -18,12 +18,13 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) -def user_left_room(distributor, user, room_id): +def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None: distributor.fire("user_left_room", user=user, room_id=room_id) @@ -63,7 +64,7 @@ class Distributor: self.pre_registration[name] = [] self.pre_registration[name].append(observer) - def fire(self, name: str, *args, **kwargs) -> None: + def fire(self, name: str, *args: Any, **kwargs: Any) -> None: """Dispatches the given signal to the registered observers. Runs the observers as a background process. Does not return a deferred. @@ -95,7 +96,7 @@ class Signal: Each observer callable may return a Deferred.""" self.observers.append(observer) - def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]": + def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]": """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is not an error to fire a signal with no observers. 
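# Sketch of the Distributor/Signal flow whose hints are tightened in the
# hunks above (the observer and its arguments are hypothetical):
distributor = Distributor()
distributor.declare("user_left_room")
distributor.observe("user_left_room", on_user_left)
distributor.fire("user_left_room", user=some_user, room_id="!room:example.org")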
@@ -103,7 +104,7 @@ class Signal: Returns a Deferred that will complete when all the observers have completed.""" - async def do(observer): + async def do(observer: Callable[..., Any]) -> Any: try: return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: @@ -120,5 +121,5 @@ class Signal: defer.gatherResults(deferreds, consumeErrors=True) ) - def __repr__(self): + def __repr__(self) -> str: return "<Signal name=%r>" % (self.name,) diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py index a447ce4e55..214eb17fbc 100644 --- a/synapse/util/gai_resolver.py +++ b/synapse/util/gai_resolver.py @@ -3,23 +3,52 @@ # We copy it here as we need to instantiate `GAIResolver` manually, but it is a # private class. - from socket import ( AF_INET, AF_INET6, AF_UNSPEC, SOCK_DGRAM, SOCK_STREAM, + AddressFamily, + SocketKind, gaierror, getaddrinfo, ) +from typing import ( + TYPE_CHECKING, + Callable, + List, + NoReturn, + Optional, + Sequence, + Tuple, + Type, + Union, +) from zope.interface import implementer from twisted.internet.address import IPv4Address, IPv6Address -from twisted.internet.interfaces import IHostnameResolver, IHostResolution +from twisted.internet.interfaces import ( + IAddress, + IHostnameResolver, + IHostResolution, + IReactorThreads, + IResolutionReceiver, +) from twisted.internet.threads import deferToThreadPool +if TYPE_CHECKING: + # The types below are copied from + # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py + # so that the type hints can match the interfaces. + from twisted.python.runtime import platform + + if platform.supportsThreads(): + from twisted.python.threadpool import ThreadPool + else: + ThreadPool = object # type: ignore[misc, assignment] + @implementer(IHostResolution) class HostResolution: @@ -27,13 +56,13 @@ class HostResolution: The in-progress resolution of a given hostname. """ - def __init__(self, name): + def __init__(self, name: str): """ Create a L{HostResolution} with the given name. """ self.name = name - def cancel(self): + def cancel(self) -> NoReturn: # IHostResolution.cancel raise NotImplementedError() @@ -62,6 +91,17 @@ _socktypeToType = { } +_GETADDRINFO_RESULT = List[ + Tuple[ + AddressFamily, + SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] +] + + @implementer(IHostnameResolver) class GAIResolver: """ @@ -69,7 +109,12 @@ class GAIResolver: L{getaddrinfo} in a thread. """ - def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo): + def __init__( + self, + reactor: IReactorThreads, + getThreadPool: Optional[Callable[[], "ThreadPool"]] = None, + getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo, + ): """ Create a L{GAIResolver}. 
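# With the hints above, constructing the resolver reads as below; the
# reactor must provide IReactorThreads and, if getThreadPool is omitted,
# the reactor's own thread pool is (as far as I can tell) used:
from twisted.internet import reactor

resolver = GAIResolver(reactor)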
@param reactor: the reactor to schedule result-delivery on @@ -89,14 +134,16 @@ class GAIResolver: ) self._getaddrinfo = getaddrinfo - def resolveHostName( + # The types on IHostnameResolver is incorrect in Twisted, see + # https://twistedmatrix.com/trac/ticket/10276 + def resolveHostName( # type: ignore[override] self, - resolutionReceiver, - hostName, - portNumber=0, - addressTypes=None, - transportSemantics="TCP", - ): + resolutionReceiver: IResolutionReceiver, + hostName: str, + portNumber: int = 0, + addressTypes: Optional[Sequence[Type[IAddress]]] = None, + transportSemantics: str = "TCP", + ) -> IHostResolution: """ See L{IHostnameResolver.resolveHostName} @param resolutionReceiver: see interface @@ -112,7 +159,7 @@ class GAIResolver: ] socketType = _transportToSocket[transportSemantics] - def get(): + def get() -> _GETADDRINFO_RESULT: try: return self._getaddrinfo( hostName, portNumber, addressFamily, socketType @@ -125,7 +172,7 @@ class GAIResolver: resolutionReceiver.resolutionBegan(resolution) @d.addCallback - def deliverResults(result): + def deliverResults(result: _GETADDRINFO_RESULT) -> None: for family, socktype, _proto, _cannoname, sockaddr in result: addrType = _afToType[family] resolutionReceiver.addressResolved( diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index b163643ca3..a0606851f7 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -92,9 +92,9 @@ def _resource_id(resource: Resource, path_seg: bytes) -> str: the mapping should looks like _resource_id(A,C) = B. Args: - resource (Resource): The *parent* Resourceb - path_seg (str): The name of the child Resource to be attached. + resource: The *parent* Resourceb + path_seg: The name of the child Resource to be attached. Returns: - str: A unique string which can be a key to the child Resource. + A unique string which can be a key to the child Resource. """ return "%s-%r" % (resource, path_seg) diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py index 9f4be757ba..8efbf061aa 100644 --- a/synapse/util/linked_list.py +++ b/synapse/util/linked_list.py @@ -84,7 +84,7 @@ class ListNode(Generic[P]): # immediately rather than at the next GC. self.cache_entry = None - def move_after(self, node: "ListNode") -> None: + def move_after(self, node: "ListNode[P]") -> None: """Move this node from its current location in the list to after the given node. """ @@ -122,7 +122,7 @@ class ListNode(Generic[P]): self.prev_node = None self.next_node = None - def _refs_insert_after(self, node: "ListNode") -> None: + def _refs_insert_after(self, node: "ListNode[P]") -> None: """Internal method to insert the node after the given node.""" # This method should only be called when we're not already in the list. 
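# _resource_id (above) keys a child by parent identity plus the raw path
# segment; for a hypothetical parent `root` and segment b"media" it yields
# a string shaped like "<Resource object at 0x...>-b'media'":
key = _resource_id(root, b"media")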
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index f8b2d7bea9..48b8195ca1 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -23,7 +23,7 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal from twisted.internet import defer -from twisted.internet.protocol import Factory +from twisted.internet.protocol import ServerFactory from synapse.config.server import ManholeConfig @@ -65,7 +65,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= -----END RSA PRIVATE KEY-----""" -def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: +def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory: """Starts a ssh listener with password authentication using the given username and password. Clients connecting to the ssh listener will find themselves in a colored python shell with @@ -105,7 +105,8 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory: factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment] factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment] - return factory + # ConchFactory is a Factory, not a ServerFactory, but they are identical. + return factory # type: ignore[return-value] class SynapseManhole(ColoredManhole): diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 1e784b3f1f..98ee49af6e 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -56,14 +56,22 @@ block_db_sched_duration = Counter( "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"] ) + +# This is dynamically created in InFlightGauge.__init__. +class _InFlightMetric(Protocol): + real_time_max: float + real_time_sum: float + + # Tracks the number of blocks currently active -in_flight = InFlightGauge( +in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge( "synapse_util_metrics_block_in_flight", "", labels=["block_name"], sub_metrics=["real_time_max", "real_time_sum"], ) + T = TypeVar("T", bound=Callable[..., Any]) @@ -180,7 +188,7 @@ class Measure: """ return self._logging_context.get_resource_usage() - def _update_in_flight(self, metrics) -> None: + def _update_in_flight(self, metrics: _InFlightMetric) -> None: """Gets called when processing in flight metrics""" assert self.start is not None duration = self.clock.time() - self.start diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index f029432191..ea1032b4fc 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -19,6 +19,8 @@ import string from collections.abc import Iterable from typing import Optional, Tuple +from netaddr import valid_ipv6 + from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" @@ -97,7 +99,10 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: raise ValueError("Invalid server name '%s'" % server_name) -VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") +# An approximation of the domain name syntax in RFC 1035, section 2.3.1. +# NB: "\Z" is not equivalent to "$". +# The latter will match the position before a "\n" at the end of a string. 
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z") def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: @@ -122,13 +127,15 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int] if host[0] == "[": if host[-1] != "]": raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) - return host, port - # otherwise it should only be alphanumerics. - if not VALID_HOST_REGEX.match(host): - raise ValueError( - "Server name '%s' contains invalid characters" % (server_name,) - ) + # valid_ipv6 raises when given an empty string + ipv6_address = host[1:-1] + if not ipv6_address or not valid_ipv6(ipv6_address): + raise ValueError( + "Server name '%s' is not a valid IPv6 address" % (server_name,) + ) + elif not VALID_HOST_REGEX.match(host): + raise ValueError("Server name '%s' has an invalid format" % (server_name,)) return host, port diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 899ee0adc8..c144ff62c1 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,10 +30,11 @@ def get_version_string(module: ModuleType) -> str: If called on a module not in a git checkout will return `__version__`. Args: - module (module) + module: The module to check the version of. Must declare a __version__ + attribute. Returns: - str + The module version (as a string). """ cached_version = version_cache.get(module) @@ -44,71 +46,37 @@ def get_version_string(module: ModuleType) -> str: version_string = module.__version__ # type: ignore[attr-defined] try: - null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) - try: - git_branch = ( - subprocess.check_output( - ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd + def _run_git_command(prefix: str, *params: str) -> str: + try: + result = ( + subprocess.check_output( + ["git", *params], stderr=subprocess.DEVNULL, cwd=cwd + ) + .strip() + .decode("ascii") ) - .strip() - .decode("ascii") - ) - git_branch = "b=" + git_branch - except (subprocess.CalledProcessError, FileNotFoundError): - # FileNotFoundError can arise when git is not installed - git_branch = "" - - try: - git_tag = ( - subprocess.check_output( - ["git", "describe", "--exact-match"], stderr=null, cwd=cwd - ) - .strip() - .decode("ascii") - ) - git_tag = "t=" + git_tag - except (subprocess.CalledProcessError, FileNotFoundError): - git_tag = "" - - try: - git_commit = ( - subprocess.check_output( - ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd - ) - .strip() - .decode("ascii") - ) - except (subprocess.CalledProcessError, FileNotFoundError): - git_commit = "" - - try: - dirty_string = "-this_is_a_dirty_checkout" - is_dirty = ( - subprocess.check_output( - ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd - ) - .strip() - .decode("ascii") - .endswith(dirty_string) - ) + return prefix + result + except (subprocess.CalledProcessError, FileNotFoundError): + return "" - git_dirty = "dirty" if is_dirty else "" - except (subprocess.CalledProcessError, FileNotFoundError): - git_dirty = "" + git_branch = _run_git_command("b=", "rev-parse", "--abbrev-ref", "HEAD") + git_tag = _run_git_command("t=", "describe", "--exact-match") + git_commit = _run_git_command("", 
"rev-parse", "--short", "HEAD") + + dirty_string = "-this_is_a_dirty_checkout" + is_dirty = _run_git_command("", "describe", "--dirty=" + dirty_string).endswith( + dirty_string + ) + git_dirty = "dirty" if is_dirty else "" if git_branch or git_tag or git_commit or git_dirty: git_version = ",".join( s for s in (git_branch, git_tag, git_commit, git_dirty) if s ) - version_string = "%s (%s)" % ( - # If the __version__ attribute doesn't exist, we'll have failed - # loudly above. - module.__version__, # type: ignore[attr-defined] - git_version, - ) + version_string = f"{version_string} ({git_version})" except Exception as e: logger.info("Failed to check for git repository: %s", e) diff --git a/synctl b/synctl index 90559ded62..0e54f4847b 100755 --- a/synctl +++ b/synctl @@ -24,7 +24,7 @@ import signal import subprocess import sys import time -from typing import Iterable +from typing import Iterable, Optional import yaml @@ -41,11 +41,24 @@ NORMAL = "\x1b[m" def pid_running(pid): try: os.kill(pid, 0) - return True except OSError as err: if err.errno == errno.EPERM: - return True - return False + pass # process exists + else: + return False + + # When running in a container, orphan processes may not get reaped and their + # PIDs may remain valid. Try to work around the issue. + try: + with open(f"/proc/{pid}/status") as status_file: + if "zombie" in status_file.read(): + return False + except Exception: + # This isn't Linux or `/proc/` is unavailable. + # Assume that the process is still running. + pass + + return True def write(message, colour=NORMAL, stream=sys.stdout): @@ -109,15 +122,14 @@ def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool) return False -def stop(pidfile: str, app: str) -> bool: +def stop(pidfile: str, app: str) -> Optional[int]: """Attempts to kill a synapse worker from the pidfile. Args: pidfile: path to file containing worker's pid app: name of the worker's appservice Returns: - True if the process stopped successfully - False if process was already stopped or an error occured + process id, or None if the process was not running """ if os.path.exists(pidfile): @@ -125,7 +137,7 @@ def stop(pidfile: str, app: str) -> bool: try: os.kill(pid, signal.SIGTERM) write("stopped %s" % (app,), colour=GREEN) - return True + return pid except OSError as err: if err.errno == errno.ESRCH: write("%s not running" % (app,), colour=YELLOW) @@ -133,14 +145,13 @@ def stop(pidfile: str, app: str) -> bool: abort("Cannot stop %s: Operation not permitted" % (app,)) else: abort("Cannot stop %s: Unknown error" % (app,)) - return False else: write( "No running worker of %s found (from %s)\nThe process might be managed by another controller (e.g. systemd)" % (app, pidfile), colour=YELLOW, ) - return False + return None Worker = collections.namedtuple( @@ -288,32 +299,23 @@ def main(): action = options.action if action == "stop" or action == "restart": - has_stopped = True + running_pids = [] for worker in workers: - if not stop(worker.pidfile, worker.app): - # A worker could not be stopped. 
- has_stopped = False + pid = stop(worker.pidfile, worker.app) + if pid is not None: + running_pids.append(pid) if start_stop_synapse: - if not stop(pidfile, MAIN_PROCESS): - has_stopped = False - if not has_stopped and action == "stop": - sys.exit(1) + pid = stop(pidfile, MAIN_PROCESS) + if pid is not None: + running_pids.append(pid) - # Wait for synapse to actually shutdown before starting it again - if action == "restart": - running_pids = [] - if start_stop_synapse and os.path.exists(pidfile): - running_pids.append(int(open(pidfile).read())) - for worker in workers: - if os.path.exists(worker.pidfile): - running_pids.append(int(open(worker.pidfile).read())) if len(running_pids) > 0: - write("Waiting for process to exit before restarting...") + write("Waiting for processes to exit...") for running_pid in running_pids: while pid_running(running_pid): time.sleep(0.2) - write("All processes exited; now restarting...") + write("All processes exited") if action == "start" or action == "restart": error = False diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py new file mode 100644 index 0000000000..cbcada0451 --- /dev/null +++ b/tests/app/test_homeserver_start.py @@ -0,0 +1,31 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
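# With stop() now returning the PID it signalled (or None), the restart path
# can wait on exactly the processes that were running; the loop above,
# condensed (targets is a hypothetical list of (pidfile, app) pairs):
running_pids = [p for p in (stop(pf, app) for pf, app in targets) if p is not None]
for pid in running_pids:
    while pid_running(pid):
        time.sleep(0.2)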
+ +import synapse.app.homeserver +from synapse.config._base import ConfigError + +from tests.config.utils import ConfigFileTestCase + + +class HomeserverAppStartTestCase(ConfigFileTestCase): + def test_wrong_start_caught(self): + # Generate a config with a worker_app + self.generate_config() + # Add a blank line as otherwise the next addition ends up on a line with a comment + self.add_lines_to_config([" "]) + self.add_lines_to_config(["worker_app: test_worker_app"]) + + # Ensure that starting master process with worker config raises an exception + with self.assertRaises(ConfigError): + synapse.app.homeserver.setup(["-c", self.config_file]) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 765258c47a..69a4e9413b 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -46,15 +46,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): "was: %r" % (config.key.macaroon_secret_key,) ) - config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) + config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) + assert config2 is not None self.assertTrue( - hasattr(config.key, "macaroon_secret_key"), + hasattr(config2.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) - if len(config.key.macaroon_secret_key) < 5: + if len(config2.key.macaroon_secret_key) < 5: self.fail( "Want macaroon secret key to be string of at least length 5," - "was: %r" % (config.key.macaroon_secret_key,) + "was: %r" % (config2.key.macaroon_secret_key,) ) def test_load_succeeds_if_macaroon_secret_key_missing(self): @@ -62,6 +63,9 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): config1 = HomeServerConfig.load_config("", ["-c", self.config_file]) config2 = HomeServerConfig.load_config("", ["-c", self.config_file]) config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) + assert config1 is not None + assert config2 is not None + assert config3 is not None self.assertEqual( config1.key.macaroon_secret_key, config2.key.macaroon_secret_key ) @@ -78,14 +82,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.registration.enable_registration) - config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) - self.assertFalse(config.registration.enable_registration) + config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) + assert config2 is not None + self.assertFalse(config2.registration.enable_registration) # Check that either config value is clobbered by the command line. - config = HomeServerConfig.load_or_generate_config( + config3 = HomeServerConfig.load_or_generate_config( "", ["-c", self.config_file, "--enable-registration"] ) - self.assertTrue(config.registration.enable_registration) + assert config3 is not None + self.assertTrue(config3.registration.enable_registration) def test_stats_enabled(self): self.generate_config_and_remove_lines_containing("enable_metrics") @@ -94,3 +100,12 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase): # The default Metrics Flags are off by default. 
config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.metrics.metrics_flags.known_servers) + + def test_depreciated_identity_server_flag_throws_error(self): + self.generate_config() + # Needed to ensure that actual key/value pair added below don't end up on a line with a comment + self.add_lines_to_config([" "]) + # Check that presence of "trust_identity_server_for_password" throws config error + self.add_lines_to_config(["trust_identity_server_for_password_resets: true"]) + with self.assertRaises(ConfigError): + HomeServerConfig.load_config("", ["-c", self.config_file]) diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py new file mode 100644 index 0000000000..17a84d20d8 --- /dev/null +++ b/tests/config/test_registration_config.py @@ -0,0 +1,78 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config import ConfigError +from synapse.config.homeserver import HomeServerConfig + +from tests.unittest import TestCase +from tests.utils import default_config + + +class RegistrationConfigTestCase(TestCase): + def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): + """ + session_lifetime should logically be larger than, or at least as large as, + all the different token lifetimes. + Test that the user is faced with configuration errors if they make it + smaller, as that configuration doesn't make sense. + """ + config_dict = default_config("test") + + # First test all the error conditions + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "nonrefreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "refreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "refresh_token_lifetime": "31m", + **config_dict, + } + ) + + # Then test all the fine conditions + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "31m", + "nonrefreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "31m", + "refreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + HomeServerConfig().parse_config_dict( + {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict} + ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index cbecc1c20f..17a9fb63a1 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -1,4 +1,4 @@ -# Copyright 2017 New Vector Ltd +# Copyright 2017-2021 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,6 +22,7 @@ import signedjson.sign from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key +from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError @@ -40,7 +41,7 @@ from synapse.storage.keys import FetchKeyResult from tests import unittest from tests.test_utils import make_awaitable -from tests.unittest import logcontext_clean +from tests.unittest import logcontext_clean, override_config class MockPerspectiveServer: @@ -197,7 +198,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # self.assertFalse(d.called) self.get_success(d) - def test_verify_for_server_locally(self): + def test_verify_for_local_server(self): """Ensure that locally signed JSON can be verified without fetching keys over federation """ @@ -209,6 +210,56 @@ class KeyringTestCase(unittest.HomeserverTestCase): d = kr.verify_json_for_server(self.hs.hostname, json1, 0) self.get_success(d) + OLD_KEY = signedjson.key.generate_signing_key("old") + + @override_config( + { + "old_signing_keys": { + f"{OLD_KEY.alg}:{OLD_KEY.version}": { + "key": encode_verify_key_base64(OLD_KEY.verify_key), + "expired_ts": 1000, + } + } + } + ) + def test_verify_for_local_server_old_key(self): + """Can also use keys in old_signing_keys for verification""" + json1 = {} + signedjson.sign.sign_json(json1, self.hs.hostname, self.OLD_KEY) + + kr = keyring.Keyring(self.hs) + d = kr.verify_json_for_server(self.hs.hostname, json1, 0) + self.get_success(d) + + def test_verify_for_local_server_unknown_key(self): + """Local keys that we no longer have should be fetched via the fetcher""" + + # the key we'll sign things with (nb, not known to the Keyring) + key2 = signedjson.key.generate_signing_key("2") + + # set up a mock fetcher which will return the key + async def get_keys( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + self.assertEqual(server_name, self.hs.hostname) + self.assertEqual(key_ids, [get_key_id(key2)]) + + return {get_key_id(key2): FetchKeyResult(get_verify_key(key2), 1200)} + + mock_fetcher = Mock() + mock_fetcher.get_keys = Mock(side_effect=get_keys) + kr = keyring.Keyring( + self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) + ) + + # sign the json + json1 = {} + signedjson.sign.sign_json(json1, self.hs.hostname, key2) + + # ... and check we can verify it. 
+ d = kr.verify_json_for_server(self.hs.hostname, json1, 0) + self.get_success(d) + def test_verify_json_for_server_with_null_valid_until_ms(self): """Tests that we correctly handle key requests for keys we've stored with a null `ts_valid_until_ms` @@ -527,6 +578,76 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) ) + def test_get_multiple_keys_from_perspectives(self): + """Check that we can correctly request multiple keys for the same server""" + + fetcher = PerspectivesKeyFetcher(self.hs) + + SERVER_NAME = "server2" + + testkey1 = signedjson.key.generate_signing_key("ver1") + testverifykey1 = signedjson.key.get_verify_key(testkey1) + testverifykey1_id = "ed25519:ver1" + + testkey2 = signedjson.key.generate_signing_key("ver2") + testverifykey2 = signedjson.key.get_verify_key(testkey2) + testverifykey2_id = "ed25519:ver2" + + VALID_UNTIL_TS = 200 * 1000 + + response1 = self.build_perspectives_response( + SERVER_NAME, + testkey1, + VALID_UNTIL_TS, + ) + response2 = self.build_perspectives_response( + SERVER_NAME, + testkey2, + VALID_UNTIL_TS, + ) + + async def post_json(destination, path, data, **kwargs): + self.assertEqual(destination, self.mock_perspective_server.server_name) + self.assertEqual(path, "/_matrix/key/v2/query") + + # check that the request is for the expected keys + q = data["server_keys"] + + self.assertEqual( + list(q[SERVER_NAME].keys()), [testverifykey1_id, testverifykey2_id] + ) + return {"server_keys": [response1, response2]} + + self.http_client.post_json.side_effect = post_json + + # fire off two separate requests; they should get merged together into a + # single HTTP hit. + request1_d = defer.ensureDeferred( + fetcher.get_keys(SERVER_NAME, [testverifykey1_id], 0) + ) + request2_d = defer.ensureDeferred( + fetcher.get_keys(SERVER_NAME, [testverifykey2_id], 0) + ) + + keys1 = self.get_success(request1_d) + self.assertIn(testverifykey1_id, keys1) + k = keys1[testverifykey1_id] + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey1) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") + + keys2 = self.get_success(request2_d) + self.assertIn(testverifykey2_id, keys2) + k = keys2[testverifykey2_id] + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey2) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver2") + + # finally, ensure that only one request was sent + self.assertEqual(self.http_client.post_json.call_count, 1) + def test_get_perspectives_own_key(self): """Check that we can get the perspectives server's own keys diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b457dad6d2..b2376e2db9 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -266,7 +266,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) # expect signing key update edu - self.assertEqual(len(self.edus), 1) + self.assertEqual(len(self.edus), 2) + self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") # sign the devices @@ -491,7 +492,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) -> None: """Check that the txn has an EDU with a signing key update.""" edus = txn["edus"] - self.assertEqual(len(edus), 1) 
+ self.assertEqual(len(edus), 2) def generate_and_upload_device_signing_key( self, user_id: str, device_id: str diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py new file mode 100644 index 0000000000..a7031a55f2 --- /dev/null +++ b/tests/federation/transport/test_client.py @@ -0,0 +1,64 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.room_versions import RoomVersions +from synapse.federation.transport.client import SendJoinParser + +from tests.unittest import TestCase + + +class SendJoinParserTestCase(TestCase): + def test_two_writes(self) -> None: + """Test that the parser can sensibly deserialise an input given in two slices.""" + parser = SendJoinParser(RoomVersions.V1, True) + parent_event = { + "content": { + "see_room_version_spec": "The event format changes depending on the room version." + }, + "event_id": "$authparent", + "room_id": "!somewhere:example.org", + "type": "m.room.minimal_pdu", + } + state = { + "content": { + "see_room_version_spec": "The event format changes depending on the room version." + }, + "event_id": "$DoNotThinkAboutTheEvent", + "room_id": "!somewhere:example.org", + "type": "m.room.minimal_pdu", + } + response = [ + 200, + { + "auth_chain": [parent_event], + "origin": "matrix.org", + "state": [state], + }, + ] + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response[:100]) + parser.write(serialised_response[100:]) + + # Retrieve the parsed SendJoinResponse + parsed_response = parser.finish() + + # Sanity check the parsing gave us sensible data. 
+ self.assertEqual(len(parsed_response.auth_events), 1, parsed_response) + self.assertEqual(len(parsed_response.state), 1, parsed_response) + self.assertEqual(parsed_response.event_dict, {}, parsed_response) + self.assertIsNone(parsed_response.event, parsed_response) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 12857053e7..03b8b8615c 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_gives_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) self.assertEqual(self.user1, res.user_id) @@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) @@ -116,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ) ) @@ -134,7 +134,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.get_failure( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ), ResourceLimitError, @@ -162,7 +162,7 @@ class AuthTestCase(unittest.HomeserverTestCase): # If not in monthly active cohort self.get_failure( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ), ResourceLimitError, @@ -179,7 +179,7 @@ class AuthTestCase(unittest.HomeserverTestCase): return_value=make_awaitable(self.clock.time_msec()) ) self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ) ) @@ -197,7 +197,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) # Ensure does not raise exception self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user1, device_id=None, valid_until_ms=None ) ) @@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase): def _get_macaroon(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index b625995d12..8705ff8943 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -66,7 +66,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=True + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) def test_map_cas_user_to_existing_user(self): @@ -89,7 +95,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected 
auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=False + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=False, + auth_provider_session_id=None, ) # Subsequent calls should map to the same mxid. @@ -98,7 +110,13 @@ class CasHandlerTestCase(HomeserverTestCase): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=False + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=False, + auth_provider_session_id=None, ) def test_map_cas_user_to_invalid_localpart(self): @@ -116,7 +134,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True + "@f=c3=b6=c3=b6:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -160,7 +184,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=True + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index be008227df..0ea4e753e2 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2021 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - from unittest.mock import Mock import synapse.api.errors import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.config.room_directory import RoomDirectoryConfig from synapse.rest.client import directory, login, room from synapse.types import RoomAlias, create_requester @@ -394,22 +393,15 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def prepare(self, reactor, clock, hs): - # We cheekily override the config to add custom alias creation rules - config = {} + def default_config(self): + config = super().default_config() + + # Add custom alias creation rules to the config. 
config["alias_creation_rules"] = [ {"user_id": "*", "alias": "#unofficial_*", "action": "allow"} ] - config["room_list_publication_rules"] = [] - rd_config = RoomDirectoryConfig() - rd_config.read_config(config) - - self.hs.config.roomdirectory.is_alias_creation_allowed = ( - rd_config.is_alias_creation_allowed - ) - - return hs + return config def test_denied(self): room_id = self.helper.create_room_as(self.user_id) @@ -417,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/room/%23test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), + {"room_id": room_id}, ) self.assertEquals(403, channel.code, channel.result) @@ -427,14 +419,35 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): channel = self.make_request( "PUT", b"directory/room/%23unofficial_test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), + {"room_id": room_id}, ) self.assertEquals(200, channel.code, channel.result) + def test_denied_during_creation(self): + """A room alias that is not allowed should be rejected during creation.""" + # Invalid room alias. + self.helper.create_room_as( + self.user_id, + expect_code=403, + extra_content={"room_alias_name": "foo"}, + ) -class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): - data = {"room_alias_name": "unofficial_test"} + def test_allowed_during_creation(self): + """A valid room alias should be allowed during creation.""" + room_id = self.helper.create_room_as( + self.user_id, + extra_content={"room_alias_name": "unofficial_test"}, + ) + channel = self.make_request( + "GET", + b"directory/room/%23unofficial_test%3Atest", + ) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(channel.json_body["room_id"], room_id) + + +class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -443,27 +456,30 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): ] hijack_auth = False - def prepare(self, reactor, clock, hs): - self.allowed_user_id = self.register_user("allowed", "pass") - self.allowed_access_token = self.login("allowed", "pass") + data = {"room_alias_name": "unofficial_test"} + allowed_localpart = "allowed" - self.denied_user_id = self.register_user("denied", "pass") - self.denied_access_token = self.login("denied", "pass") + def default_config(self): + config = super().default_config() - # This time we add custom room list publication rules - config = {} - config["alias_creation_rules"] = [] + # Add custom room list publication rules to the config. 
config["room_list_publication_rules"] = [ + { + "user_id": "@" + self.allowed_localpart + "*", + "alias": "#unofficial_*", + "action": "allow", + }, {"user_id": "*", "alias": "*", "action": "deny"}, - {"user_id": self.allowed_user_id, "alias": "*", "action": "allow"}, ] - rd_config = RoomDirectoryConfig() - rd_config.read_config(config) + return config - self.hs.config.roomdirectory.is_publishing_room_allowed = ( - rd_config.is_publishing_room_allowed - ) + def prepare(self, reactor, clock, hs): + self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") + self.allowed_access_token = self.login(self.allowed_localpart, "pass") + + self.denied_user_id = self.register_user("denied", "pass") + self.denied_access_token = self.login("denied", "pass") return hs @@ -505,10 +521,23 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): self.allowed_user_id, tok=self.allowed_access_token, extra_content=self.data, - is_public=False, + is_public=True, expect_code=200, ) + def test_denied_publication_with_invalid_alias(self): + """ + Try to create a room, register an alias for it, and publish it, + as a user WITH permission to publish rooms. + """ + self.helper.create_room_as( + self.allowed_user_id, + tok=self.allowed_access_token, + extra_content={"room_alias_name": "foo"}, + is_public=True, + expect_code=403, + ) + def test_can_create_as_private_room_after_rejection(self): """ After failing to publish a room with an alias as a user without publish permission, diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 0c3b86fda9..f0723892e4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -162,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): local_user = "@boris:" + self.hs.hostname device_id = "xyz" fallback_key = {"alg1:k1": "key1"} + fallback_key2 = {"alg1:k2": "key2"} otk = {"alg1:k2": "key2"} # we shouldn't have any unused fallback keys yet @@ -213,6 +214,35 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) + # re-uploading the same fallback key should still result in no unused fallback + # keys + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key}, + ) + ) + + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, []) + + # uploading a new fallback key should result in an unused fallback key + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key2}, + ) + ) + + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, ["alg1"]) + # if the user uploads a one-time key, the next claim should fetch the # one-time key, and then go back to the fallback self.get_success( @@ -238,7 +268,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.assertEqual( res, - {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, ) def test_replace_master_key(self): diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a25c89bd5b..cfe3de5266 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -252,13 +252,6 @@ class OidcHandlerTestCase(HomeserverTestCase): with 
patch.object(self.provider, "load_metadata", patched_load_metadata): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) - # Return empty key set if JWKS are not used - self.provider._scopes = [] # not asking the openid scope - self.http_client.get_json.reset_mock() - jwks = self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_not_called() - self.assertEqual(jwks, {"keys": []}) - @override_config({"oidc_config": DEFAULT_CONFIG}) def test_validate_config(self): """Provider metadata is extensively validated.""" @@ -455,7 +448,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, "oidc", request, client_redirect_url, None, new_user=True + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=True, + auth_provider_session_id=None, ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -482,17 +481,58 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.provider._scopes = [] # do not ask the "openid" scope + self.provider._user_profile_method = "userinfo_endpoint" + token = { + "type": "bearer", + "access_token": "access_token", + } + self.provider._exchange_code = simple_async_mock(return_value=token) self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, "oidc", request, client_redirect_url, None, new_user=False + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=False, + auth_provider_session_id=None, ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() + # With an ID token, userinfo fetching and sid in the ID token + self.provider._user_profile_method = "userinfo_endpoint" + token = { + "type": "bearer", + "access_token": "access_token", + "id_token": "id_token", + } + id_token = { + "sid": "abcdefgh", + } + self.provider._parse_id_token = simple_async_mock(return_value=id_token) + self.provider._exchange_code = simple_async_mock(return_value=token) + auth_handler.complete_sso_login.reset_mock() + self.provider._fetch_userinfo.reset_mock() + self.get_success(self.handler.handle_oidc_callback(request)) + + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=False, + auth_provider_session_id=id_token["sid"], + ) + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.provider._fetch_userinfo.assert_called_once_with(token) + self.render_error.assert_not_called() + # Handle userinfo fetching error self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) @@ -776,6 +816,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url, {"phone": "1234567"}, new_user=True, + auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -790,7 +831,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "oidc", ANY, ANY, None, new_user=True + "@test_user:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -801,7 +848,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True + "@test_user_2:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -838,14 +891,26 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -860,7 +925,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -896,7 +967,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False + "@TEST_USER_2:test", + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -934,7 +1011,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. 
auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", "oidc", ANY, ANY, None, new_user=True + "@test_user1:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -1018,7 +1101,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", "oidc", ANY, ANY, None, new_user=True + "@tester:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -1043,7 +1132,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", "oidc", ANY, ANY, None, new_user=True + "@tester:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -1156,7 +1251,7 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={}) + provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 7dd4a5a367..08e9730d4d 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -31,7 +31,10 @@ from tests.unittest import override_config # (possibly experimental) login flows we expect to appear in the list after the normal # ones -ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] +ADDITIONAL_LOGIN_FLOWS = [ + {"type": "m.login.application_service"}, + {"type": "uk.half-shot.msc2778.login.application_service"}, +] # a mock instance which the dummy auth providers delegate to, so we can see what's going # on diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index db691c4c1c..cd6f2c77ae 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self): - self.store.count_monthly_users = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.store.count_monthly_users = Mock( # type: ignore[assignment] return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception @@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self): - self.store.get_monthly_active_count = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.store.get_monthly_active_count = Mock( # type: ignore[assignment] return_value=make_awaitable(self.lots_of_users) ) self.get_failure( @@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - self.store.get_monthly_active_count = Mock( + # Type ignore: mypy doesn't like us assigning to methods. 
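+ # (Added note: make_awaitable wraps a plain value in an already-completed awaitable, so the Mock can stand in for an async storage method.)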
+ self.store.get_monthly_active_count = Mock( # type: ignore[assignment] return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index d3d0bf1ac5..e5a6a6c747 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -14,6 +14,8 @@ from typing import Any, Iterable, List, Optional, Tuple from unittest import mock +from twisted.internet.defer import ensureDeferred + from synapse.api.constants import ( EventContentFields, EventTypes, @@ -30,7 +32,7 @@ from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEnt from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, create_requester from tests import unittest @@ -247,7 +249,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -261,7 +263,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): expected = [(self.space, [self.room]), (self.room, ())] self._assert_rooms(result, expected) - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(user2), self.space) + ) self._assert_hierarchy(result, expected) # If the space is made invite-only, it should no longer be viewable. @@ -272,7 +276,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + self.get_failure( + self.handler.get_room_hierarchy(create_requester(user2), self.space), + AuthError, + ) # If the space is made world-readable it should return a result. self.helper.send_state( @@ -284,7 +291,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(user2), self.space) + ) self._assert_hierarchy(result, expected) # Make it not world-readable again and confirm it results in an error. @@ -295,7 +304,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + self.get_failure( + self.handler.get_room_hierarchy(create_requester(user2), self.space), + AuthError, + ) # Join the space and results should be returned. 
self.helper.invite(self.space, targ=user2, tok=self.token) @@ -303,7 +315,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(user2), self.space) + ) self._assert_hierarchy(result, expected) # Attempting to view an unknown room returns the same error. @@ -312,10 +326,67 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): AuthError, ) self.get_failure( - self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), + self.handler.get_room_hierarchy( + create_requester(user2), "#not-a-space:" + self.hs.hostname + ), AuthError, ) + def test_room_hierarchy_cache(self) -> None: + """In-flight room hierarchy requests are deduplicated.""" + # Run two `get_room_hierarchy` calls up until they block. + deferred1 = ensureDeferred( + self.handler.get_room_hierarchy(create_requester(self.user), self.space) + ) + deferred2 = ensureDeferred( + self.handler.get_room_hierarchy(create_requester(self.user), self.space) + ) + + # Complete the two calls. + result1 = self.get_success(deferred1) + result2 = self.get_success(deferred2) + + # Both `get_room_hierarchy` calls should return the same result. + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_hierarchy(result1, expected) + self._assert_hierarchy(result2, expected) + self.assertIs(result1, result2) + + # A subsequent `get_room_hierarchy` call should not reuse the result. + result3 = self.get_success( + self.handler.get_room_hierarchy(create_requester(self.user), self.space) + ) + self._assert_hierarchy(result3, expected) + self.assertIsNot(result1, result3) + + def test_room_hierarchy_cache_sharing(self) -> None: + """Room hierarchy responses for different users are not shared.""" + user2 = self.register_user("user2", "pass") + + # Make the room within the space invite-only. + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + + # Run two `get_room_hierarchy` calls for different users up until they block. + deferred1 = ensureDeferred( + self.handler.get_room_hierarchy(create_requester(self.user), self.space) + ) + deferred2 = ensureDeferred( + self.handler.get_room_hierarchy(create_requester(user2), self.space) + ) + + # Complete the two calls. + result1 = self.get_success(deferred1) + result2 = self.get_success(deferred2) + + # The `get_room_hierarchy` calls should return different results. 
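+ # (Added note: user2 was never invited to the invite-only room, so their view of the hierarchy stops at the space itself.)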
+ self._assert_hierarchy(result1, [(self.space, [self.room]), (self.room, ())]) + self._assert_hierarchy(result2, [(self.space, [self.room])]) + def _create_room_with_join_rule( self, join_rule: str, room_version: Optional[str] = None, **extra_content ) -> str: @@ -410,7 +481,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ] self._assert_rooms(result, expected) - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(user2), self.space) + ) self._assert_hierarchy(result, expected) def test_complex_space(self): @@ -452,7 +525,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -467,7 +540,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): room_ids.append(self.room) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, limit=7) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, limit=7 + ) ) # The result should have the space and all of the links, plus some of the # rooms and a pagination token. @@ -479,7 +554,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Check the next page. result = self.get_success( self.handler.get_room_hierarchy( - self.user, self.space, limit=5, from_token=result["next_batch"] + create_requester(self.user), + self.space, + limit=5, + from_token=result["next_batch"], ) ) # The result should have the space and the room in it, along with a link @@ -499,20 +577,22 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): room_ids.append(self.room) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, limit=7) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, limit=7 + ) ) self.assertIn("next_batch", result) # Changing the room ID, suggested-only, or max-depth causes an error. self.get_failure( self.handler.get_room_hierarchy( - self.user, self.room, from_token=result["next_batch"] + create_requester(self.user), self.room, from_token=result["next_batch"] ), SynapseError, ) self.get_failure( self.handler.get_room_hierarchy( - self.user, + create_requester(self.user), self.space, suggested_only=True, from_token=result["next_batch"], @@ -521,14 +601,19 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self.get_failure( self.handler.get_room_hierarchy( - self.user, self.space, max_depth=0, from_token=result["next_batch"] + create_requester(self.user), + self.space, + max_depth=0, + from_token=result["next_batch"], ), SynapseError, ) # An invalid token is ignored. self.get_failure( - self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, from_token="foo" + ), SynapseError, ) @@ -554,14 +639,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Test just the space itself. result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, max_depth=0 + ) ) expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] self._assert_hierarchy(result, expected) # A single additional layer. 
result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, max_depth=1 + ) ) expected += [ (rooms[0], ()), @@ -571,7 +660,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # A few layers. result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, max_depth=3 + ) ) expected += [ (rooms[1], ()), @@ -602,7 +693,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_rooms(result, expected) result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -684,7 +775,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -851,7 +942,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) @@ -909,7 +1000,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): new=summarize_remote_room_hierarchy, ): result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) + self.handler.get_room_hierarchy(create_requester(self.user), self.space) ) self._assert_hierarchy(result, expected) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 8cfc184fef..50551aa6e3 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -130,7 +130,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "redirect_uri", None, new_user=True + "@test_user:test", + "saml", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -156,7 +162,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "", None, new_user=False + "@test_user:test", + "saml", + request, + "", + None, + new_user=False, + auth_provider_session_id=None, ) # Subsequent calls should map to the same mxid. @@ -165,7 +177,13 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "", None, new_user=False + "@test_user:test", + "saml", + request, + "", + None, + new_user=False, + auth_provider_session_id=None, ) def test_map_saml_response_to_invalid_localpart(self): @@ -213,7 +231,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. 
auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", "saml", request, "", None, new_user=True + "@test_user1:test", + "saml", + request, + "", + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -309,7 +333,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "redirect_uri", None, new_user=True + "@test_user:test", + "saml", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index 1f9a2f9b1d..c8cc21cadd 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -36,8 +36,11 @@ class ServerNameTestCase(unittest.TestCase): "localhost:http", # non-numeric port "1234]", # smells like ipv6 literal but isn't "[1234", + "[1.2.3.4]", "underscore_.com", "percent%65.com", + "newline.com\n", + ".empty-label.com", "1234:5678:80", # too many colons ] for i in test_data: diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 90f800e564..f8cba7b645 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -128,6 +128,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.auth_handler = hs.get_auth_handler() + self.store = hs.get_datastore() def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated @@ -408,13 +409,7 @@ class EmailPusherTests(HomeserverTestCase): self.hs.get_datastore().db_pool.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success( - self.hs.get_datastore().db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.hs.get_datastore().db_pool.updates.do_next_background_update(100), - by=0.1, - ) + self.wait_for_background_updates() # Check that all pushers with unlinked addresses were deleted pushers = self.get_success( diff --git a/tests/replication/_base.py b/tests/replication/_base.py index eac4664b41..cb02eddf07 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from twisted.internet.protocol import Protocol from twisted.web.resource import Resource from synapse.app.generic_worker import GenericWorkerServer -from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.client import ReplicationDataHandler @@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): unlike `BaseStreamTestCase`. 
""" - servlets: List[Callable[[HomeServer, JsonResource], None]] = [] - def setUp(self): super().setUp() diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 0a6e4795ee..596ba5a0c9 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -17,6 +17,7 @@ from unittest.mock import patch from synapse.api.room_versions import RoomVersion from synapse.rest import admin from synapse.rest.client import login, room, sync +from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request @@ -193,7 +194,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # # Worker2's event stream position will not advance until we call # __aexit__ again. - actx = worker_hs2.get_datastore()._stream_id_gen.get_next() + worker_store2 = worker_hs2.get_datastore() + assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) + + actx = worker_store2._stream_id_gen.get_next() self.get_success(actx.__aenter__()) response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 192073c520..3adadcb46b 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os import urllib.parse +from http import HTTPStatus from unittest.mock import Mock from twisted.internet.defer import Deferred @@ -41,7 +41,7 @@ class VersionTestCase(unittest.HomeserverTestCase): def test_version_string(self): channel = self.make_request("GET", self.url, shorthand=False) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual( {"server_version", "python_version"}, set(channel.json_body.keys()) ) @@ -70,11 +70,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): content={"localpart": "test"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) group_id = channel.json_body["group_id"] - self._check_group(group_id, expect_code=200) + self._check_group(group_id, expect_code=HTTPStatus.OK) # Invite/join another user @@ -82,13 +82,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) url = "/groups/%s/self/accept_invite" % (group_id,) channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check other user knows they're in the group self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -103,10 +103,10 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): content={"localpart": "test"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, 
msg=channel.json_body) - # Check group returns 404 - self._check_group(group_id, expect_code=404) + # Check group returns HTTPStatus.NOT_FOUND + self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND) # Check users don't think they're in the group self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -122,15 +122,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token)""" channel = self.make_request("GET", b"/joined_groups", access_token=access_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) return channel.json_body["groups"] @@ -210,10 +208,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Should be quarantined self.assertEqual( - 404, - int(channel.code), + HTTPStatus.NOT_FOUND, + channel.code, msg=( - "Expected to receive a 404 on accessing quarantined media: %s" + "Expected to receive an HTTPStatus.NOT_FOUND on accessing quarantined media: %s" % server_and_media_id ), ) @@ -232,8 +230,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Expect a forbidden error self.assertEqual( - 403, - int(channel.result["code"]), + HTTPStatus.FORBIDDEN, + channel.code, msg="Expected forbidden on quarantining media as a non-admin", ) @@ -247,8 +245,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Expect a forbidden error self.assertEqual( - 403, - int(channel.result["code"]), + HTTPStatus.FORBIDDEN, + channel.code, msg="Expected forbidden on quarantining media as a non-admin", ) @@ -279,7 +277,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Should be successful - self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code) # Quarantine the media url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( @@ -292,7 +290,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Attempt to access the media self._ensure_quarantined(admin_user_tok, server_name_and_media_id) @@ -348,11 +346,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual( - json.loads(channel.result["body"].decode("utf-8")), - {"num_quarantined": 2}, - "Expected 2 quarantined items", + channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" ) # Convert mxc URLs to server/media_id strings @@ -396,11 +392,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual( - json.loads(channel.result["body"].decode("utf-8")), - {"num_quarantined": 2}, - "Expected 2 quarantined items", +
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" ) # Attempt to access each piece of media @@ -432,7 +426,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) channel = self.make_request("POST", url, access_token=admin_user_tok) self.pump(1.0) - self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( @@ -444,11 +438,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual( - json.loads(channel.result["body"].decode("utf-8")), - {"num_quarantined": 1}, - "Expected 1 quarantined item", + channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item" ) # Attempt to access each piece of media, the first should fail, the @@ -467,10 +459,58 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Shouldn't be quarantined self.assertEqual( - 200, - int(channel.code), + HTTPStatus.OK, + channel.code, msg=( - "Expected to receive a 200 on accessing not-quarantined media: %s" + "Expected to receive an HTTPStatus.OK on accessing not-quarantined media: %s" % server_and_media_id_2 ), ) + + +class PurgeHistoryTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}" + self.url_status = "/_synapse/admin/v1/purge_history_status/" + + def test_purge_history(self): + """ + Simple test of the purge history API. + Tests only that it is possible to call it, get an HTTPStatus.OK status and a purge_id. + """ + + channel = self.make_request( + "POST", + self.url, + content={"delete_local_events": True, "purge_up_to_ts": 0}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("purge_id", channel.json_body) + purge_id = channel.json_body["purge_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status + purge_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 78c48db552..4d152c0d66 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -11,10 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
+from http import HTTPStatus +from typing import Collection + +from parameterized import parameterized + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin +from synapse.api.errors import Codes from synapse.rest.client import login from synapse.server import HomeServer +from synapse.storage.background_updates import BackgroundUpdater +from synapse.util import Clock from tests import unittest @@ -25,12 +34,66 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - def _register_bg_update(self): + @parameterized.expand( + [ + ("GET", "/_synapse/admin/v1/background_updates/enabled"), + ("POST", "/_synapse/admin/v1/background_updates/enabled"), + ("GET", "/_synapse/admin/v1/background_updates/status"), + ("POST", "/_synapse/admin/v1/background_updates/start_job"), + ] + ) + def test_requester_is_no_admin(self, method: str, url: str) -> None: + """ + If the user is not a server admin, an HTTPStatus.FORBIDDEN error is returned. + """ + + self.register_user("user", "pass", admin=False) + other_user_tok = self.login("user", "pass") + + channel = self.make_request( + method, + url, + content={}, + access_token=other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self) -> None: + """ + If parameters are invalid, an error is returned. + """ + url = "/_synapse/admin/v1/background_updates/start_job" + + # empty content + channel = self.make_request( + "POST", + url, + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # job_name invalid + channel = self.make_request( + "POST", + url, + content={"job_name": "unknown"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + def _register_bg_update(self) -> None: "Adds a bg update but doesn't start it" async def _fake_update(progress, batch_size) -> int: @@ -52,7 +115,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) ) - def test_status_empty(self): + def test_status_empty(self) -> None: """Test the status API works.""" channel = self.make_request( @@ -60,14 +123,14 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Background updates should be enabled, but none should be running.
self.assertDictEqual( channel.json_body, {"current_updates": {}, "enabled": True} ) - def test_status_bg_update(self): + def test_status_bg_update(self) -> None: """Test the status API works with a background update.""" # Create a new background update @@ -75,14 +138,14 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self._register_bg_update() self.store.db_pool.updates.start_doing_background_updates() - self.reactor.pump([1.0, 1.0]) + self.reactor.pump([1.0, 1.0, 1.0]) channel = self.make_request( "GET", "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Background updates should be enabled, and one should be running. self.assertDictEqual( @@ -91,16 +154,18 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.1, + "average_items_per_ms": 0.001, "total_duration_ms": 1000.0, - "total_item_count": 100, + "total_item_count": ( + BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + ), } }, "enabled": True, }, ) - def test_enabled(self): + def test_enabled(self) -> None: """Test the enabled API works.""" # Create a new background update @@ -114,7 +179,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/enabled", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) # Disable the BG updates @@ -124,7 +189,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": False}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": False}) # Advance a bit and get the current status, note this will finish the in @@ -137,16 +202,18 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual( channel.json_body, { "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.1, + "average_items_per_ms": 0.001, "total_duration_ms": 1000.0, - "total_item_count": 100, + "total_item_count": ( + BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + ), } }, "enabled": False, @@ -162,7 +229,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # There should be no change from the previous /status response. 
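# (Added note: advancing the reactor performs no work while the updater is disabled, so the counters below are unchanged.)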
self.assertDictEqual( @@ -171,9 +238,11 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.1, + "average_items_per_ms": 0.001, "total_duration_ms": 1000.0, - "total_item_count": 100, + "total_item_count": ( + BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + ), } }, "enabled": False, @@ -188,7 +257,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": True}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) @@ -199,7 +268,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Background updates should be enabled and making progress. self.assertDictEqual( @@ -208,11 +277,92 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "current_updates": { "master": { "name": "test_update", - "average_items_per_ms": 0.1, + "average_items_per_ms": 0.001, "total_duration_ms": 2000.0, - "total_item_count": 200, + "total_item_count": ( + 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE + ), } }, "enabled": True, }, ) + + @parameterized.expand( + [ + ("populate_stats_process_rooms", ["populate_stats_process_rooms"]), + ( + "regenerate_directory", + [ + "populate_user_directory_createtables", + "populate_user_directory_process_rooms", + "populate_user_directory_process_users", + "populate_user_directory_cleanup", + ], + ), + ] + ) + def test_start_backround_job(self, job_name: str, updates: Collection[str]) -> None: + """ + Test that background updates are added to the database and processed.
+ + Args: + job_name: name of the job to start via the API + updates: collection of background updates to be started + """ + + # no background update is waiting + self.assertTrue( + self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/background_updates/start_job", + content={"job_name": job_name}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # test that each background update is waiting now + for update in updates: + self.assertFalse( + self.get_success( + self.store.db_pool.updates.has_completed_background_update(update) + ) + ) + + self.wait_for_background_updates() + + # background updates are done + self.assertTrue( + self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ) + ) + + def test_start_backround_job_twice(self) -> None: + """Test that adding a background update twice returns an error.""" + + # add job to database + self.get_success( + self.store.db_pool.simple_insert( + table="background_updates", + values={ + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + }, + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/background_updates/start_job", + content={"job_name": "populate_stats_process_rooms"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index a3679be205..f7080bda87 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import urllib.parse +from http import HTTPStatus from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -30,7 +34,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -47,17 +51,21 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_no_auth(self, method: str): + def test_no_auth(self, method: str) -> None: """ Try to get a device of a user without authentication. """ channel = self.make_request(method, self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_requester_is_no_admin(self, method: str): + def test_requester_is_no_admin(self, method: str) -> None: """ If the user is not a server admin, an error is returned.
""" @@ -67,13 +75,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_does_not_exist(self, method: str): + def test_user_does_not_exist(self, method: str) -> None: """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = ( "/_synapse/admin/v2/users/@unknown_person:test/devices/%s" @@ -86,13 +98,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_is_not_local(self, method: str): + def test_user_is_not_local(self, method: str) -> None: """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = ( "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s" @@ -105,12 +117,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_device(self): + def test_unknown_device(self) -> None: """ - Tests that a lookup for a device that does not exist returns either 404 or 200. + Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. """ url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote( self.other_user @@ -122,7 +134,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) channel = self.make_request( @@ -131,7 +143,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) channel = self.make_request( "DELETE", @@ -139,10 +151,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # Delete unknown device returns status 200 - self.assertEqual(200, channel.code, msg=channel.json_body) + # Delete unknown device returns status HTTPStatus.OK + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_update_device_too_long_display_name(self): + def test_update_device_too_long_display_name(self) -> None: """ Update a device with a display name that is invalid (too long). 
""" @@ -167,7 +179,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content=update, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. @@ -177,12 +189,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_no_display_name(self): + def test_update_no_display_name(self) -> None: """ - Tests that a update for a device without JSON returns a 200 + Tests that a update for a device without JSON returns a HTTPStatus.OK """ # Set iniital display name. update = {"display_name": "new display"} @@ -198,7 +210,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Ensure the display name was not updated. channel = self.make_request( @@ -207,10 +219,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_display_name(self): + def test_update_display_name(self) -> None: """ Tests a normal successful update of display name """ @@ -222,7 +234,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content={"display_name": "new displayname"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check new display_name channel = self.make_request( @@ -231,10 +243,10 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) - def test_get_device(self): + def test_get_device(self) -> None: """ Tests that a normal lookup for a device is successfully """ @@ -244,7 +256,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) # Check that all fields are available self.assertIn("user_id", channel.json_body) @@ -253,7 +265,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertIn("last_seen_ip", channel.json_body) self.assertIn("last_seen_ts", channel.json_body) - def test_delete_device(self): + def test_delete_device(self) -> None: """ Tests that a remove of a device is successfully """ @@ -269,7 +281,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Ensure that the number of devices is decreased res = self.get_success(self.handler.get_devices_by_user(self.other_user)) @@ -283,7 +295,7 
@@ class DevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -293,16 +305,20 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to list devices of an user without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -314,12 +330,16 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" channel = self.make_request( @@ -328,12 +348,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" @@ -343,10 +363,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_user_has_no_devices(self): + def test_user_has_no_devices(self) -> None: """ Tests that a normal lookup for devices is successfully if user has no devices @@ -359,11 +379,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["devices"])) - def test_get_devices(self): + def test_get_devices(self) -> None: """ Tests that a normal lookup for devices is successfully """ @@ -379,7 +399,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(number_devices, 
channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) @@ -399,7 +419,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -411,16 +431,20 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete devices of an user without authentication. """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -432,12 +456,16 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" channel = self.make_request( @@ -446,12 +474,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" @@ -461,12 +489,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_devices(self): + def test_unknown_devices(self) -> None: """ - Tests that a remove of a device that does not exist returns 200. + Tests that a remove of a device that does not exist returns HTTPStatus.OK. 
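Where a suite exercises one test body across several HTTP verbs it uses `@parameterized.expand`, as in `DeviceRestTestCase` above. A self-contained sketch with plain `unittest` rather than the Synapse harness (class and assertion here are hypothetical):

    import unittest

    from parameterized import parameterized


    class MethodMatrixTestCase(unittest.TestCase):
        @parameterized.expand(["GET", "PUT", "DELETE"])
        def test_no_auth(self, method: str) -> None:
            # parameterized generates one test per entry, named e.g.
            # test_no_auth_0_GET, test_no_auth_1_PUT, test_no_auth_2_DELETE.
            self.assertIn(method, {"GET", "PUT", "DELETE"})


    if __name__ == "__main__":
        unittest.main()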
""" channel = self.make_request( "POST", @@ -475,10 +503,10 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": ["unknown_device1", "unknown_device2"]}, ) - # Delete unknown devices returns status 200 - self.assertEqual(200, channel.code, msg=channel.json_body) + # Delete unknown devices returns status HTTPStatus.OK + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_delete_devices(self): + def test_delete_devices(self) -> None: """ Tests that a remove of devices is successfully """ @@ -505,7 +533,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": device_ids}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) res = self.get_success(self.handler.get_devices_by_user(self.other_user)) self.assertEqual(0, len(res)) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index e9ef89731f..4f89f8b534 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus +from typing import List -import json +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, report_event, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -29,7 +34,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -70,18 +75,22 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/event_reports" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to get an event report without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. 
""" channel = self.make_request( @@ -90,10 +99,14 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self): + def test_default_success(self) -> None: """ Testing list of reported events """ @@ -104,13 +117,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit(self): + def test_limit(self) -> None: """ Testing list of reported events with limit """ @@ -121,13 +134,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["event_reports"]) - def test_from(self): + def test_from(self) -> None: """ Testing list of reported events with a defined starting point (from) """ @@ -138,13 +151,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 15) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit_and_from(self): + def test_limit_and_from(self) -> None: """ Testing list of reported events with a defined starting point and limit """ @@ -155,13 +168,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["event_reports"]), 10) self._check_fields(channel.json_body["event_reports"]) - def test_filter_room(self): + def test_filter_room(self) -> None: """ Testing list of reported events with a filter of room """ @@ -172,7 +185,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -181,7 +194,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): for report in channel.json_body["event_reports"]: self.assertEqual(report["room_id"], 
self.room_id1) - def test_filter_user(self): + def test_filter_user(self) -> None: """ Testing list of reported events with a filter of user """ @@ -192,7 +205,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -201,7 +214,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): for report in channel.json_body["event_reports"]: self.assertEqual(report["user_id"], self.other_user) - def test_filter_user_and_room(self): + def test_filter_user_and_room(self) -> None: """ Testing list of reported events with a filter of user and room """ @@ -212,7 +225,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertNotIn("next_token", channel.json_body) @@ -222,7 +235,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(report["user_id"], self.other_user) self.assertEqual(report["room_id"], self.room_id1) - def test_valid_search_order(self): + def test_valid_search_order(self) -> None: """ Testing search order. Order by timestamps. """ @@ -234,7 +247,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -252,7 +265,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -263,9 +276,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) report += 1 - def test_invalid_search_order(self): + def test_invalid_search_order(self) -> None: """ - Testing that a invalid search order returns a 400 + Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST """ channel = self.make_request( @@ -274,13 +287,17 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual("Unknown direction: bar", channel.json_body["error"]) - def test_limit_is_negative(self): + def test_limit_is_negative(self) -> None: """ - Testing that a negative limit parameter returns a 400 + Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST """ channel = self.make_request( @@ -289,12 +306,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - 
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_from_is_negative(self): + def test_from_is_negative(self) -> None: """ - Testing that a negative from parameter returns a 400 + Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST """ channel = self.make_request( @@ -303,10 +324,14 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_next_token(self): + def test_next_token(self) -> None: """ Testing that `next_token` appears at the right place """ @@ -319,7 +344,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -332,7 +357,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -345,7 +370,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -359,12 +384,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertNotIn("next_token", channel.json_body) - def _create_event_and_report(self, room_id, user_tok): + def _create_event_and_report(self, room_id: str, user_tok: str) -> None: """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -372,12 +397,14 @@ class EventReportsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - json.dumps({"score": -100, "reason": "this makes me sad"}), + {"score": -100, "reason": "this makes me sad"}, access_token=user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def _create_event_and_report_without_parameters(self, room_id, user_tok): + def _create_event_and_report_without_parameters( + self, room_id: str, user_tok: str + ) -> None: """Create and 
report an event, but omit reason and score""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -385,12 +412,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - json.dumps({}), + {}, access_token=user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def _check_fields(self, content): + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in an event report""" for c in content: self.assertIn("id", c) @@ -413,7 +440,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -433,18 +460,22 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): # first created event report gets `id`=2 self.url = "/_synapse/admin/v1/event_reports/2" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to get event report without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ channel = self.make_request( @@ -453,10 +484,14 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self): + def test_default_success(self) -> None: """ Testing get a reported event """ @@ -467,12 +502,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) - def test_invalid_report_id(self): + def test_invalid_report_id(self) -> None: """ - Testing that an invalid `report_id` returns a 400. + Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST. 
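The `limit`, `from` and `next_token` tests above all describe one pagination contract: `total` reports the full count, and `next_token` appears only while rows remain. A client-side sketch of draining such an endpoint, where `fetch_page` is a hypothetical stand-in for the authenticated GET:

    from typing import Any, Callable, Dict, List


    def drain_reports(fetch_page: Callable[[int], Dict[str, Any]]) -> List[Dict[str, Any]]:
        """fetch_page(from_) is assumed to return the decoded JSON body:
        {"event_reports": [...], "total": T}, plus "next_token" while
        more rows remain beyond the current window."""
        rows: List[Dict[str, Any]] = []
        from_ = 0
        while True:
            body = fetch_page(from_)
            rows.extend(body["event_reports"])
            if "next_token" not in body:
                return rows
            from_ = int(body["next_token"])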
""" # `report_id` is negative @@ -482,7 +517,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -496,7 +535,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -510,16 +553,20 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", channel.json_body["error"], ) - def test_report_id_not_found(self): + def test_report_id_not_found(self) -> None: """ - Testing that a not existing `report_id` returns a 404. + Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND. """ channel = self.make_request( @@ -528,11 +575,15 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual("Event report not found", channel.json_body["error"]) - def _create_event_and_report(self, room_id, user_tok): + def _create_event_and_report(self, room_id: str, user_tok: str) -> None: """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -540,12 +591,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - json.dumps({"score": -100, "reason": "this makes me sad"}), + {"score": -100, "reason": "this makes me sad"}, access_token=user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def _check_fields(self, content): + def _check_fields(self, content: JsonDict) -> None: """Checks that all attributes are present in a event report""" self.assertIn("id", content) self.assertIn("received_ts", content) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py new file mode 100644 index 0000000000..5188499ef2 --- /dev/null +++ b/tests/rest/admin/test_federation.py @@ -0,0 +1,456 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from http import HTTPStatus +from typing import List, Optional + +from parameterized import parameterized + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.types import JsonDict + +from tests import unittest + + +class FederationTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): + self.store = hs.get_datastore() + self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.url = "/_synapse/admin/v1/federation/destinations" + + @parameterized.expand( + [ + ("/_synapse/admin/v1/federation/destinations",), + ("/_synapse/admin/v1/federation/destinations/dummy",), + ] + ) + def test_requester_is_no_admin(self, url: str): + """ + If the user is not a server admin, an error 403 is returned. + """ + + self.register_user("user", "pass", admin=False) + other_user_tok = self.login("user", "pass") + + channel = self.make_request( + "GET", + url, + content={}, + access_token=other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. 
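A compact sketch of the validation matrix this test exercises (not Synapse's actual servlet parsing, which lives elsewhere): negative paging values are rejected as `M_INVALID_PARAM`, while unrecognised `order_by`/`dir` values surface as `M_UNKNOWN`:

    from typing import Optional, Tuple

    VALID_ORDER_BY = {
        "destination",
        "retry_last_ts",
        "retry_interval",
        "failure_ts",
        "last_successful_stream_ordering",
    }


    def validate_destination_list_params(
        limit: int, from_: int, order_by: Optional[str], direction: Optional[str]
    ) -> Optional[Tuple[int, str]]:
        """Return (status, errcode) for a rejected request, or None if valid."""
        if limit < 0 or from_ < 0:
            return (400, "M_INVALID_PARAM")
        if order_by is not None and order_by not in VALID_ORDER_BY:
            return (400, "M_UNKNOWN")
        if direction is not None and direction not in ("b", "f"):
            return (400, "M_UNKNOWN")
        return None


    assert validate_destination_list_params(-5, 0, None, None) == (400, "M_INVALID_PARAM")
    assert validate_destination_list_params(10, 0, "bar", None) == (400, "M_UNKNOWN")
    assert validate_destination_list_params(10, 0, None, "bar") == (400, "M_UNKNOWN")
    assert validate_destination_list_params(10, 0, "destination", "f") is None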
+ """ + + # negative limit + channel = self.make_request( + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # unkown order_by + channel = self.make_request( + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid destination + channel = self.make_request( + "GET", + self.url + "/dummy", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_limit(self): + """ + Testing list of destinations with limit + """ + + number_destinations = 20 + self._create_destinations(number_destinations) + + channel = self.make_request( + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), 5) + self.assertEqual(channel.json_body["next_token"], "5") + self._check_fields(channel.json_body["destinations"]) + + def test_from(self): + """ + Testing list of destinations with a defined starting point (from) + """ + + number_destinations = 20 + self._create_destinations(number_destinations) + + channel = self.make_request( + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["destinations"]) + + def test_limit_and_from(self): + """ + Testing list of destinations with a defined starting point and limit + """ + + number_destinations = 20 + self._create_destinations(number_destinations) + + channel = self.make_request( + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(channel.json_body["next_token"], "15") + self.assertEqual(len(channel.json_body["destinations"]), 10) + self._check_fields(channel.json_body["destinations"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_destinations = 20 + self._create_destinations(number_destinations) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=20", + 
access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), number_destinations) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), number_destinations) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), 19) + self.assertEqual(channel.json_body["next_token"], "19") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def test_list_all_destinations(self): + """ + List all destinations. + """ + number_destinations = 5 + self._create_destinations(number_destinations) + + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(number_destinations, len(channel.json_body["destinations"])) + self.assertEqual(number_destinations, channel.json_body["total"]) + + # Check that all fields are available + self._check_fields(channel.json_body["destinations"]) + + def test_order_by(self): + """ + Testing order list with parameter `order_by` + """ + + def _order_test( + expected_destination_list: List[str], + order_by: Optional[str], + dir: Optional[str] = None, + ): + """Request the list of destinations in a certain order. + Assert that order is what we expect + + Args: + expected_destination_list: The list of user_id in the order + we expect to get back from the server + order_by: The type of ordering to give the server + dir: The direction of ordering to give the server + """ + + url = f"{self.url}?" 
+ if order_by is not None: + url += f"order_by={order_by}&" + if dir is not None and dir in ("b", "f"): + url += f"dir={dir}" + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], len(expected_destination_list)) + + returned_order = [ + row["destination"] for row in channel.json_body["destinations"] + ] + self.assertEqual(expected_destination_list, returned_order) + self._check_fields(channel.json_body["destinations"]) + + # create destinations + dest = [ + ("sub-a.example.com", 100, 300, 200, 300), + ("sub-b.example.com", 200, 200, 100, 100), + ("sub-c.example.com", 300, 100, 300, 200), + ] + for ( + destination, + failure_ts, + retry_last_ts, + retry_interval, + last_successful_stream_ordering, + ) in dest: + self.get_success( + self.store.set_destination_retry_timings( + destination, failure_ts, retry_last_ts, retry_interval + ) + ) + self.get_success( + self.store.set_destination_last_successful_stream_ordering( + destination, last_successful_stream_ordering + ) + ) + + # order by default (destination) + _order_test([dest[0][0], dest[1][0], dest[2][0]], None) + _order_test([dest[0][0], dest[1][0], dest[2][0]], None, "f") + _order_test([dest[2][0], dest[1][0], dest[0][0]], None, "b") + + # order by destination + _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination") + _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination", "f") + _order_test([dest[2][0], dest[1][0], dest[0][0]], "destination", "b") + + # order by failure_ts + _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts") + _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts", "f") + _order_test([dest[2][0], dest[1][0], dest[0][0]], "failure_ts", "b") + + # order by retry_last_ts + _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts") + _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts", "f") + _order_test([dest[0][0], dest[1][0], dest[2][0]], "retry_last_ts", "b") + + # order by retry_interval + _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval") + _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval", "f") + _order_test([dest[2][0], dest[0][0], dest[1][0]], "retry_interval", "b") + + # order by last_successful_stream_ordering + _order_test( + [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering" + ) + _order_test( + [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering", "f" + ) + _order_test( + [dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b" + ) + + def test_search_term(self): + """Test that searching for a destination works correctly""" + + def _search_test( + expected_destination: Optional[str], + search_term: str, + ): + """Search for a destination and check that the returned destination is a match + + Args: + expected_destination: The destination expected to be returned by the API.
+ Set to None to expect zero results for the search + search_term: The term to search for room names with + """ + url = f"{self.url}?destination={search_term}" + channel = self.make_request( + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Check that destinations were returned + self.assertTrue("destinations" in channel.json_body) + self._check_fields(channel.json_body["destinations"]) + destinations = channel.json_body["destinations"] + + # Check that the expected number of destinations were returned + expected_destination_count = 1 if expected_destination else 0 + self.assertEqual(len(destinations), expected_destination_count) + self.assertEqual(channel.json_body["total"], expected_destination_count) + + if expected_destination: + # Check that the first returned destination is correct + self.assertEqual(expected_destination, destinations[0]["destination"]) + + number_destinations = 3 + self._create_destinations(number_destinations) + + # Test searching + _search_test("sub0.example.com", "0") + _search_test("sub0.example.com", "sub0") + + _search_test("sub1.example.com", "1") + _search_test("sub1.example.com", "1.") + + # Test case insensitive + _search_test("sub0.example.com", "SUB0") + + _search_test(None, "foo") + _search_test(None, "bar") + + def test_get_single_destination(self): + """ + Get one specific destinations. + """ + self._create_destinations(5) + + channel = self.make_request( + "GET", + self.url + "/sub0.example.com", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual("sub0.example.com", channel.json_body["destination"]) + + # Check that all fields are available + # convert channel.json_body into a List + self._check_fields([channel.json_body]) + + def _create_destinations(self, number_destinations: int): + """Create a number of destinations + + Args: + number_destinations: Number of destinations to be created + """ + for i in range(0, number_destinations): + dest = f"sub{i}.example.com" + self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50)) + self.get_success( + self.store.set_destination_last_successful_stream_ordering(dest, 100) + ) + + def _check_fields(self, content: List[JsonDict]): + """Checks that the expected destination attributes are present in content + + Args: + content: List that is checked for content + """ + for c in content: + self.assertIn("destination", c) + self.assertIn("retry_last_ts", c) + self.assertIn("retry_interval", c) + self.assertIn("failure_ts", c) + self.assertIn("last_successful_stream_ordering", c) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index db0e78c039..81e578fd26 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -12,16 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
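The `test_search_term` cases above fix the filter semantics for the federation destination list: a case-insensitive substring match on the destination name. A toy sketch of that rule against the `subN.example.com` fixtures that `_create_destinations` seeds:

    from typing import List


    def filter_destinations(destinations: List[str], term: str) -> List[str]:
        """Case-insensitive substring match, mirroring the expectations above."""
        needle = term.lower()
        return [d for d in destinations if needle in d.lower()]


    dests = [f"sub{i}.example.com" for i in range(3)]
    assert filter_destinations(dests, "SUB0") == ["sub0.example.com"]
    assert filter_destinations(dests, "1.") == ["sub1.example.com"]
    assert filter_destinations(dests, "foo") == []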
- -import json import os +from http import HTTPStatus from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -39,7 +42,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -48,7 +51,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.filepaths = MediaFilePaths(hs.config.media.media_store_path) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete media without authentication. """ @@ -56,10 +59,14 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): channel = self.make_request("DELETE", url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -74,12 +81,16 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_does_not_exist(self): + def test_media_does_not_exist(self) -> None: """ - Tests that a lookup for a media that does not exist returns a 404 + Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") @@ -89,12 +100,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_media_is_not_local(self): + def test_media_is_not_local(self) -> None: """ - Tests that a lookup for a media that is not a local returns a 400 + Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") @@ -104,10 +115,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_delete_media(self): + def test_delete_media(self) -> None: """ Tests that delete a media is successfully """ @@ -117,7 +128,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 + 
upload_resource, + SMALL_PNG, + tok=self.admin_user_tok, + expect_code=HTTPStatus.OK, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -137,10 +151,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Should be successful self.assertEqual( - 200, + HTTPStatus.OK, channel.code, msg=( - "Expected to receive a 200 on accessing media: %s" % server_and_media_id + "Expected to receive a HTTPStatus.OK on accessing media: %s" + % server_and_media_id ), ) @@ -157,7 +172,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -174,10 +189,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual( - 404, + HTTPStatus.NOT_FOUND, channel.code, msg=( - "Expected to receive a 404 on accessing deleted media: %s" + "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" % server_and_media_id ), ) @@ -196,7 +211,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -209,17 +224,21 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # Move clock up to somewhat realistic time self.reactor.advance(1000000000) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete media without authentication. """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -232,12 +251,16 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_is_not_local(self): + def test_media_is_not_local(self) -> None: """ - Tests that a lookup for media that is not local returns a 400 + Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" @@ -247,10 +270,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_missing_parameter(self): + def test_missing_parameter(self) -> None: """ If the parameter `before_ts` is missing, an error is returned. 
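The tests in this file pin down when `POST /_synapse/admin/v1/media/<server>/delete` may remove a local media item. A sketch of that selection rule as the assertions imply it (not the actual SQL; `used_as_profile` is a hypothetical stand-in for the avatar checks behind `keep_profiles`, and the `size_gt=67`/`size_gt=66` pair in `test_keep_media_by_size` below implies SMALL_PNG is 67 bytes):

    from typing import Optional


    def media_is_deletable(
        created_ts: int,
        last_access_ts: Optional[int],
        media_length: int,
        before_ts: int,
        size_gt: int = 0,
        keep_profiles: bool = True,
        used_as_profile: bool = False,
    ) -> bool:
        # Never-accessed media falls back to its creation time.
        last_used = last_access_ts if last_access_ts is not None else created_ts
        if last_used >= before_ts:
            return False  # newer than the cutoff: keep
        if media_length <= size_gt:
            return False  # not strictly larger than size_gt: keep
        if keep_profiles and used_as_profile:
            return False  # user/room avatar is protected: keep
        return True


    assert not media_is_deletable(0, None, 67, before_ts=10, size_gt=67)
    assert media_is_deletable(0, None, 67, before_ts=10, size_gt=66)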
""" @@ -260,13 +283,17 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( "Missing integer query parameter 'before_ts'", channel.json_body["error"] ) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -276,7 +303,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -289,7 +320,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " @@ -303,7 +338,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter size_gt must be a string representing a positive integer.", @@ -316,14 +355,18 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", channel.json_body["error"], ) - def test_delete_media_never_accessed(self): + def test_delete_media_never_accessed(self) -> None: """ Tests that media deleted if it is older than `before_ts` and never accessed `last_access_ts` is `NULL` and `created_ts` < `before_ts` @@ -345,7 +388,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -354,7 +397,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_date(self): + def test_keep_media_by_date(self) -> None: """ Tests that media is not deleted if it is newer than `before_ts` """ @@ -370,7 +413,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + 
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -382,7 +425,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -391,7 +434,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_size(self): + def test_keep_media_by_size(self) -> None: """ Tests that media is not deleted if its size is smaller than or equal to `size_gt` @@ -406,7 +449,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -417,7 +460,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -426,7 +469,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_user_avatar(self): + def test_keep_media_by_user_avatar(self) -> None: """ Tests that we do not delete media if is used as a user avatar Tests parameter `keep_profiles` @@ -439,10 +482,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.admin_user,), - content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), + content={"avatar_url": "mxc://%s" % (server_and_media_id,)}, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -450,7 +493,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -461,7 +504,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -470,7 +513,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_room_avatar(self): + def 
test_keep_media_by_room_avatar(self) -> None: """ Tests that we do not delete media if it is used as a room avatar Tests parameter `keep_profiles` @@ -484,10 +527,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.avatar" % (room_id,), - content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), + content={"url": "mxc://%s" % (server_and_media_id,)}, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -495,7 +538,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -506,7 +549,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -515,7 +558,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def _create_media(self): + def _create_media(self) -> str: """ Create a media and return media_id and server_and_media_id """ @@ -523,7 +566,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 + upload_resource, + SMALL_PNG, + tok=self.admin_user_tok, + expect_code=HTTPStatus.OK, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -534,7 +580,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): return server_and_media_id - def _access_media(self, server_and_media_id, expect_success=True): + def _access_media(self, server_and_media_id, expect_success=True) -> None: """ Try to access a media and check the result """ @@ -554,10 +600,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): if expect_success: self.assertEqual( - 200, + HTTPStatus.OK, channel.code, msg=( - "Expected to receive a 200 on accessing media: %s" + "Expected to receive a HTTPStatus.OK on accessing media: %s" % server_and_media_id ), ) @@ -565,10 +611,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertTrue(os.path.exists(local_path)) else: self.assertEqual( - 404, + HTTPStatus.NOT_FOUND, channel.code, msg=( - "Expected to receive a 404 on accessing deleted media: %s" + "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" % (server_and_media_id) ), ) @@ -584,7 +630,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() self.server_name = hs.hostname @@ -597,7 +643,10 @@ class 
QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 + upload_resource, + SMALL_PNG, + tok=self.admin_user_tok, + expect_code=HTTPStatus.OK, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -606,7 +655,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/media/%s/%s/%s" @parameterized.expand(["quarantine", "unquarantine"]) - def test_no_auth(self, action: str): + def test_no_auth(self, action: str) -> None: """ Try to protect media without authentication. """ @@ -617,11 +666,15 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): b"{}", ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["quarantine", "unquarantine"]) - def test_requester_is_no_admin(self, action: str): + def test_requester_is_no_admin(self, action: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -634,10 +687,14 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_quarantine_media(self): + def test_quarantine_media(self) -> None: """ Tests that quarantining and remove from quarantine a media is successfully """ @@ -652,7 +709,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -665,13 +722,13 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) self.assertFalse(media_info["quarantined_by"]) - def test_quarantine_protected_media(self): + def test_quarantine_protected_media(self) -> None: """ Tests that quarantining from protected media fails """ @@ -690,7 +747,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) # verify that is not in quarantine @@ -706,7 +763,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() @@ -718,7 +775,10 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, 
tok=self.admin_user_tok, expect_code=200 + upload_resource, + SMALL_PNG, + tok=self.admin_user_tok, + expect_code=HTTPStatus.OK, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -727,18 +787,22 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/media/%s/%s" @parameterized.expand(["protect", "unprotect"]) - def test_no_auth(self, action: str): + def test_no_auth(self, action: str) -> None: """ Try to protect media without authentication. """ channel = self.make_request("POST", self.url % (action, self.media_id), b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["protect", "unprotect"]) - def test_requester_is_no_admin(self, action: str): + def test_requester_is_no_admin(self, action: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -751,10 +815,14 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_protect_media(self): + def test_protect_media(self) -> None: """ Tests that protect and unprotect a media is successfully """ @@ -769,7 +837,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -782,7 +850,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -799,7 +867,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -809,17 +877,21 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.url = "/_synapse/admin/v1/purge_media_cache" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete media without authentication. """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self): + def test_requester_is_not_admin(self) -> None: """ If the user is not a server admin, an error is returned. 
""" @@ -832,10 +904,14 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -845,7 +921,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -858,7 +938,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 9bac423ae0..350a62dda6 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -11,13 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import random import string +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -28,7 +32,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -38,7 +42,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/registration_tokens" - def _new_token(self, **kwargs): + def _new_token(self, **kwargs) -> str: """Helper function to create a token.""" token = kwargs.get( "token", @@ -60,13 +64,17 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # CREATION - def test_create_no_auth(self): + def test_create_no_auth(self) -> None: """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_create_requester_not_admin(self): + def test_create_requester_not_admin(self) -> None: """Try to create a token while not an admin.""" channel = self.make_request( "POST", @@ -74,10 +82,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_create_using_defaults(self): + def test_create_using_defaults(self) -> None: """Create a token using all the defaults.""" channel = self.make_request( "POST", @@ -86,14 +98,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_specifying_fields(self): + def test_create_specifying_fields(self) -> None: """Create a token specifying the value of all fields.""" # As many of the allowed characters as possible with length <= 64 token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-" @@ -110,14 +122,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) 
self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_with_null_value(self): + def test_create_with_null_value(self) -> None: """Create a token specifying unlimited uses and no expiry.""" data = { "uses_allowed": None, @@ -131,14 +143,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_token_too_long(self): + def test_create_token_too_long(self) -> None: """Check token longer than 64 chars is invalid.""" data = {"token": "a" * 65} @@ -149,10 +161,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_invalid_chars(self): + def test_create_token_invalid_chars(self) -> None: """Check you can't create token with invalid characters.""" data = { "token": "abc/def", @@ -165,10 +181,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_already_exists(self): + def test_create_token_already_exists(self) -> None: """Check you can't create token that already exists.""" data = { "token": "abcd", @@ -180,7 +200,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) + self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body) channel2 = self.make_request( "POST", @@ -188,10 +208,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_unable_to_generate_token(self): + def test_create_unable_to_generate_token(self) -> None: """Check right error is raised when server can't generate unique token.""" # Create all possible single character tokens tokens = [] @@ -220,9 +240,9 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(500, channel.code, msg=channel.json_body) - def test_create_uses_allowed(self): + def test_create_uses_allowed(self) -> None: """Check you can only create a token with good values for uses_allowed.""" # Should work with 0 (token is invalid from the start) channel = self.make_request( @@ 
-231,7 +251,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) # Should fail with negative integer @@ -241,7 +261,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -251,10 +275,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_expiry_time(self): + def test_create_expiry_time(self) -> None: """Check you can't create a token with an invalid expiry_time.""" # Should fail with a time in the past channel = self.make_request( @@ -263,7 +291,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() - 10000}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -273,10 +305,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() + 1000000.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_length(self): + def test_create_length(self) -> None: """Check you can only generate a token with a valid length.""" # Should work with 64 channel = self.make_request( @@ -285,7 +321,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 64}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 64) # Should fail with 0 @@ -295,7 +331,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -305,7 +345,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": -5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) 
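# A minimal sketch, outside this diff, of what the creation rules above mean for
# a caller. The endpoint, body fields and limits mirror the tests; the server
# URL, token value and the use of `requests` are illustrative assumptions.
import requests

def create_registration_token(base_url: str, admin_token: str) -> dict:
    body = {
        "token": "example.token",  # optional; at most 64 chars from [A-Za-z0-9._~-]
        "uses_allowed": 5,         # non-negative integer, or None for unlimited
        "expiry_time": None,       # future timestamp in ms, or None for no expiry
        # "length": 16,            # used only when "token" is omitted; must be 1..64
    }
    resp = requests.post(
        f"{base_url}/_synapse/admin/v1/registration_tokens/new",
        json=body,
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    # Violating any rule above returns HTTPStatus.BAD_REQUEST with errcode
    # INVALID_PARAM, as the assertions in this hunk check.
    resp.raise_for_status()
    return resp.json()  # echoes token, uses_allowed, expiry_time, pending, completed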
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a float @@ -315,7 +359,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 8.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with 65 @@ -325,22 +373,30 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 65}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # UPDATING - def test_update_no_auth(self): + def test_update_no_auth(self) -> None: """Try to update a token without authentication.""" channel = self.make_request( "PUT", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_update_requester_not_admin(self): + def test_update_requester_not_admin(self) -> None: """Try to update a token while not an admin.""" channel = self.make_request( "PUT", @@ -348,10 +404,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_update_non_existent(self): + def test_update_non_existent(self) -> None: """Try to update a token that doesn't exist.""" channel = self.make_request( "PUT", @@ -360,10 +420,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_update_uses_allowed(self): + def test_update_uses_allowed(self) -> None: """Test updating just uses_allowed.""" # Create new token using default values token = self._new_token() @@ -375,7 +439,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertIsNone(channel.json_body["expiry_time"]) @@ -386,7 +450,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertIsNone(channel.json_body["expiry_time"]) @@ -397,7 +461,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 
None}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -408,7 +472,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -418,10 +486,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_expiry_time(self): + def test_update_expiry_time(self) -> None: """Test updating just expiry_time.""" # Create new token using default values token = self._new_token() @@ -434,7 +506,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -445,7 +517,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": None}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -457,7 +529,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": past_time}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail a float @@ -467,10 +543,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time + 0.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_both(self): + def test_update_both(self) -> None: """Test updating both uses_allowed and expiry_time.""" # Create new token using default values token = self._new_token() @@ -488,11 +568,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) - 
def test_update_invalid_type(self): + def test_update_invalid_type(self) -> None: """Test using invalid types doesn't work.""" # Create new token using default values token = self._new_token() @@ -509,22 +589,30 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # DELETING - def test_delete_no_auth(self): + def test_delete_no_auth(self) -> None: """Try to delete a token without authentication.""" channel = self.make_request( "DELETE", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_delete_requester_not_admin(self): + def test_delete_requester_not_admin(self) -> None: """Try to delete a token while not an admin.""" channel = self.make_request( "DELETE", @@ -532,10 +620,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_delete_non_existent(self): + def test_delete_non_existent(self) -> None: """Try to delete a token that doesn't exist.""" channel = self.make_request( "DELETE", @@ -544,10 +636,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_delete(self): + def test_delete(self) -> None: """Test deleting a token.""" # Create new token using default values token = self._new_token() @@ -559,21 +655,25 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # GETTING ONE - def test_get_no_auth(self): + def test_get_no_auth(self) -> None: """Try to get a token without authentication.""" channel = self.make_request( "GET", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_get_requester_not_admin(self): + def test_get_requester_not_admin(self) -> None: """Try to get a token while not an admin.""" channel = self.make_request( "GET", @@ -581,10 +681,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, 
channel.json_body["errcode"]) - def test_get_non_existent(self): + def test_get_non_existent(self) -> None: """Try to get a token that doesn't exist.""" channel = self.make_request( "GET", @@ -593,10 +697,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_get(self): + def test_get(self) -> None: """Test getting a token.""" # Create new token using default values token = self._new_token() @@ -608,7 +716,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -617,13 +725,17 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # LISTING - def test_list_no_auth(self): + def test_list_no_auth(self) -> None: """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_list_requester_not_admin(self): + def test_list_requester_not_admin(self) -> None: """Try to list tokens while not an admin.""" channel = self.make_request( "GET", @@ -631,10 +743,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_list_all(self): + def test_list_all(self) -> None: """Test listing all tokens.""" # Create new token using default values token = self._new_token() @@ -646,7 +762,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 1) token_info = channel.json_body["registration_tokens"][0] self.assertEqual(token_info["token"], token) @@ -655,7 +771,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(token_info["pending"], 0) self.assertEqual(token_info["completed"], 0) - def test_list_invalid_query_parameter(self): + def test_list_invalid_query_parameter(self) -> None: """Test with `valid` query parameter not `true` or `false`.""" channel = self.make_request( "GET", @@ -664,9 +780,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) - def _test_list_query_parameter(self, valid: str): + def _test_list_query_parameter(self, 
valid: str) -> None: """Helper used to test both valid=true and valid=false.""" # Create 2 valid and 2 invalid tokens. now = self.hs.get_clock().time_msec() @@ -696,17 +816,17 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 2) token_info_1 = channel.json_body["registration_tokens"][0] token_info_2 = channel.json_body["registration_tokens"][1] self.assertIn(token_info_1["token"], tokens) self.assertIn(token_info_2["token"], tokens) - def test_list_valid(self): + def test_list_valid(self) -> None: """Test listing just valid tokens.""" self._test_list_query_parameter(valid="true") - def test_list_invalid(self): + def test_list_invalid(self) -> None: """Test listing just invalid tokens.""" self._test_list_query_parameter(valid="false") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 11ec54c82e..22f9aa6234 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import json import urllib.parse from http import HTTPStatus from typing import List, Optional @@ -20,10 +18,15 @@ from unittest.mock import Mock from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes +from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -39,7 +42,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): room.register_deprecated_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.event_creation_handler = hs.get_event_creation_handler() hs.config.consent.user_consent_version = "1" @@ -65,49 +68,48 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ channel = self.make_request( "DELETE", self.url, - json.dumps({}), + {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self): """ - Check that unknown rooms/server return error 404. + Check that unknown rooms/server return 200 """ url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test" channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_room_is_not_valid(self): """ - Check that invalid room names, return an error 400. 
+ Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. """ url = "/_synapse/admin/v1/rooms/%s" % "invalidroom" channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -117,16 +119,15 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ Tests that the user ID must be from local server but it does not have to exist. """ - body = json.dumps({"new_room_user_id": "@unknown:test"}) channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content={"new_room_user_id": "@unknown:test"}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -136,16 +137,15 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ Check that only local users can create new room to move members. """ - body = json.dumps({"new_room_user_id": "@not:exist.bla"}) channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content={"new_room_user_id": "@not:exist.bla"}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -155,32 +155,30 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ If parameter `block` is not boolean, return an error """ - body = json.dumps({"block": "NotBool"}) channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content={"block": "NotBool"}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self): """ If parameter `purge` is not boolean, return an error """ - body = json.dumps({"purge": "NotBool"}) channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content={"purge": "NotBool"}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self): @@ -197,16 +195,14 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": True, "purge": True}) - channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content={"block": True, "purge": True}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) 
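# A minimal sketch, outside this diff, of the synchronous v1 shutdown request
# that DeleteRoomTestCase drives. The URL and token are placeholders; the body
# fields and response keys match the assertions in these tests.
import requests

def delete_room_v1(base_url: str, admin_token: str, room_id: str) -> dict:
    resp = requests.request(
        "DELETE",
        f"{base_url}/_synapse/admin/v1/rooms/{room_id}",
        json={
            "block": True,   # must be a boolean, else errcode BAD_JSON
            "purge": True,   # must be a boolean, else errcode BAD_JSON
            # "new_room_user_id": "@admin:example.com",  # optional; must be local
        },
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    resp.raise_for_status()
    # The body lists kicked_users and failed_to_kick_users, plus new_room_id
    # (None when no replacement room was requested).
    return resp.json()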
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -230,16 +226,14 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": False, "purge": True}) - channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content={"block": False, "purge": True}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -264,16 +258,14 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": True, "purge": False}) - channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content={"block": True, "purge": False}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -304,9 +296,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) # The room is now blocked. - self.assertEqual( - HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self._is_blocked(room_id) def test_shutdown_room_consent(self): @@ -326,7 +316,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert that the user is getting consent error self.helper.send( - self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + self.room_id, + body="foo", + tok=self.other_user_tok, + expect_code=HTTPStatus.FORBIDDEN, ) # Test that room is not purged @@ -340,11 +333,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - json.dumps({"new_room_user_id": self.admin_user}), + {"new_room_user_id": self.admin_user}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -370,10 +363,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), + {"history_visibility": "world_readable"}, access_token=self.other_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -386,11 +379,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - json.dumps({"new_room_user_id": self.admin_user}), + 
{"new_room_user_id": self.admin_user}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -405,7 +398,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=403) + self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) def _is_blocked(self, room_id, expect=True): """Assert that the room is blocked or not""" @@ -446,18 +439,629 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + url = "events?timeout=0&room_id=" + room_id + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + +class DeleteRoomV2TestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.consent.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v2/rooms/{self.room_id}" + self.url_status_by_room_id = ( + f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status" + ) + self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/" + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_requester_is_no_admin(self, method: str, url: str): + """ + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + """ + + channel = self.make_request( + method, + url % self.room_id, + content={}, + access_token=self.other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return 200 + + This is important, as it allows incomplete vestiges of rooms to be cleared up + even if the create event/etc is missing. 
+ """ + room_id = "!unknown:test" + channel = self.make_request( + "DELETE", + f"/_synapse/admin/v2/rooms/{room_id}", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + f"/_synapse/admin/v2/rooms/{room_id}/delete_status", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ] + ) + def test_room_is_not_valid(self, method: str, url: str): + """ + Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + """ + + channel = self.make_request( + method, + url % "invalidroom", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "invalidroom is not a legal room ID", + channel.json_body["error"], + ) + + def test_new_room_user_does_not_exist(self): + """ + Tests that the user ID must be from local server but it does not have to exist. + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@unknown:test"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + def test_new_room_user_is_not_local(self): + """ + Check that only local users can create new room to move members. 
+ """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@not:exist.bla"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "User must be our own: @not:exist.bla", + channel.json_body["error"], + ) + + def test_block_is_not_bool(self): + """ + If parameter `block` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"block": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"purge": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_delete_expired_status(self): + """Test that the task status is removed after expiration.""" + + # first task, do not purge, that we can create a second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": False}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id1 = channel.json_body["delete_id"] + + # go ahead + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + # second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id2 = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(2, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual("complete", channel.json_body["results"][1]["status"]) + self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"]) + self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"]) + + # get status after more than clearing time for first task + # second task is not cleared + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"]) + + # get status after more than clearing time for all tasks + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_delete_same_room_twice(self): + """Test that the 
call for delete a room at second time gives an exception.""" + + body = {"new_room_user_id": self.admin_user} + + # first call to delete room + # and do not wait for finish the task + first_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + await_result=False, + ) + + # second call to delete room + second_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + ) + self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body + ) + self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"]) + self.assertEqual( + f"History purge already in progress for {self.room_id}", + second_channel.json_body["error"], + ) + + # get result of first call + first_channel.await_result() + self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body) + self.assertIn("delete_id", first_channel.json_body) + + # check status after finish the task + self._test_result( + first_channel.json_body["delete_id"], + self.other_user, + expect_new_room=True, + ) + + def test_purge_room_and_block(self): + """Test to purge a room and block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_purge_room_and_not_block(self): + """Test to purge a room and do not block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": False, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. 
+ """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": False}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + self.room_id, + body="foo", + tok=self.other_user_tok, + expect_code=HTTPStatus.FORBIDDEN, + ) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + Members will be moved to a new room and will receive a message. 
+ """ + self.event_creation_handler._block_events_without_consent_error = None + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) + channel = self.make_request( + "PUT", + url.encode("ascii"), + content={"history_visibility": "world_readable"}, + access_token=self.other_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + # Assert we can no longer peek into the room + self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) + + def _is_blocked(self, room_id: str, expect: bool = True) -> None: + """Assert that the room is blocked or not""" + d = self.store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _has_no_members(self, room_id: str) -> None: + """Assert there is now no longer anyone in the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def _is_member(self, room_id: str, user_id: str) -> None: + """Test that user is member of the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertIn(user_id, users_in_room) + + def _is_purged(self, room_id: str) -> None: + """Test that the following tables have been purged of all rows related to the room.""" + for table in PURGE_TABLES: + count = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") + + def _assert_peek(self, room_id: str, expect_code: int) -> None: + """Assert that the admin user can (or cannot) peek into the room.""" + + url = f"rooms/{room_id}/initialSync" + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) url = "events?timeout=0&room_id=" + room_id channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + def _test_result( + self, + delete_id: str, + kicked_user: str, + expect_new_room: bool = False, + ) -> None: + """ + Test that the result is the expected. 
+ Uses both APIs (status by room_id and delete_id) + + Args: + delete_id: id of this purge + kicked_user: a user_id which is kicked from the room + expect_new_room: if we expect that a new room was created + """ + + # get information by room_id + channel_room_id = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body + ) + self.assertEqual(1, len(channel_room_id.json_body["results"])) + self.assertEqual( + delete_id, channel_room_id.json_body["results"][0]["delete_id"] ) + # get information by delete_id + channel_delete_id = self.make_request( + "GET", + self.url_status_by_delete_id + delete_id, + access_token=self.admin_user_tok, + ) + self.assertEqual( + HTTPStatus.OK, + channel_delete_id.code, + msg=channel_delete_id.json_body, + ) + + # test values that are the same in both responses + for content in [ + channel_room_id.json_body["results"][0], + channel_delete_id.json_body, + ]: + self.assertEqual("complete", content["status"]) + self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0]) + self.assertIn("failed_to_kick_users", content["shutdown_room"]) + self.assertIn("local_aliases", content["shutdown_room"]) + self.assertNotIn("error", content) + + if expect_new_room: + self.assertIsNotNone(content["shutdown_room"]["new_room_id"]) + else: + self.assertIsNone(content["shutdown_room"]["new_room_id"]) + class RoomTestCase(unittest.HomeserverTestCase): """Test /room admin API.""" @@ -469,12 +1073,12 @@ class RoomTestCase(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Create user self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - def test_list_rooms(self): + def test_list_rooms(self) -> None: """Test that we can list rooms""" # Create 3 test rooms total_rooms = 3 @@ -494,7 +1098,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -538,7 +1142,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # We shouldn't receive a next token here as there's no further rooms to show self.assertNotIn("next_batch", channel.json_body) - def test_list_rooms_pagination(self): + def test_list_rooms_pagination(self) -> None: """Test that we can get a full list of rooms through pagination""" # Create 5 test rooms total_rooms = 5 @@ -578,9 +1182,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -620,9 +1222,9 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_correct_room_attributes(self): + def 
test_correct_room_attributes(self) -> None: """Test the correct attributes for a room are returned""" # Create a test room room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -643,7 +1245,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -675,7 +1277,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -703,7 +1305,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(test_room_name, r["name"]) self.assertEqual(test_alias, r["canonical_alias"]) - def test_room_list_sort_order(self): + def test_room_list_sort_order(self) -> None: """Test room list sort ordering. alphabetical name versus number of members, reversing the order, etc. """ @@ -712,7 +1314,7 @@ class RoomTestCase(unittest.HomeserverTestCase): order_type: str, expected_room_list: List[str], reverse: bool = False, - ): + ) -> None: """Request the list of rooms in a certain order. Assert that order is what we expect @@ -730,7 +1332,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -841,7 +1443,7 @@ class RoomTestCase(unittest.HomeserverTestCase): _order_test("state_events", [room_id_3, room_id_2, room_id_1]) _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) - def test_search_term(self): + def test_search_term(self) -> None: """Test that searching for a room works correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -869,8 +1471,8 @@ class RoomTestCase(unittest.HomeserverTestCase): def _search_test( expected_room_id: Optional[str], search_term: str, - expected_http_code: int = 200, - ): + expected_http_code: int = HTTPStatus.OK, + ) -> None: """Search for a room and check that the returned room's id is a match Args: @@ -887,7 +1489,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != 200: + if expected_http_code != HTTPStatus.OK: return # Check that rooms were returned @@ -930,7 +1532,7 @@ class RoomTestCase(unittest.HomeserverTestCase): _search_test(None, "foo") _search_test(None, "bar") - _search_test(None, "", expected_http_code=400) + _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST) # Test that the whole room id returns the room _search_test(room_id_1, room_id_1) @@ -944,7 +1546,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Test search local part of alias _search_test(room_id_1, "alias1") - def test_search_term_non_ascii(self): + def test_search_term_non_ascii(self) -> None: """Test that searching for a room with non-ASCII characters works correctly""" # Create test room @@ -967,11 +1569,11 @@ class 
RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) - def test_single_room(self): + def test_single_room(self) -> None: """Test that a single room can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1000,7 +1602,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIn("room_id", channel.json_body) self.assertIn("name", channel.json_body) @@ -1022,7 +1624,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id_1, channel.json_body["room_id"]) - def test_single_room_devices(self): + def test_single_room_devices(self) -> None: """Test that `joined_local_devices` can be requested correctly""" room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1032,7 +1634,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["joined_local_devices"]) # Have another user join the room @@ -1046,7 +1648,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(2, channel.json_body["joined_local_devices"]) # leave room @@ -1058,10 +1660,10 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) - def test_room_members(self): + def test_room_members(self) -> None: """Test that room members can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1089,7 +1691,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] @@ -1102,14 +1704,14 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] ) self.assertEqual(channel.json_body["total"], 3) - def test_room_state(self): + def test_room_state(self) -> None: """Test that room state can be requested correctly""" # Create two test rooms room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ 
-1120,13 +1722,15 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIn("state", channel.json_body) # testing that the state events match is painful and not done here. We assume that # the create_room already does the right thing, so no need to verify that we got # the state events it created. - def _set_canonical_alias(self, room_id: str, test_alias: str, admin_user_tok: str): + def _set_canonical_alias( + self, room_id: str, test_alias: str, admin_user_tok: str + ) -> None: # Create a new alias to this room url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) channel = self.make_request( @@ -1135,7 +1739,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1161,7 +1765,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1176,124 +1780,117 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): ) self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. 
""" - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content={"user_id": self.second_user_id}, access_token=self.second_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If a parameter is missing, return an error """ - body = json.dumps({"unknown_parameter": "@unknown:test"}) channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content={"unknown_parameter": "@unknown:test"}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - def test_local_user_does_not_exist(self): + def test_local_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ - body = json.dumps({"user_id": "@unknown:test"}) channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content={"user_id": "@unknown:test"}, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_remote_user(self): + def test_remote_user(self) -> None: """ Check that only local user can join rooms. """ - body = json.dumps({"user_id": "@not:exist.bla"}) channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content={"user_id": "@not:exist.bla"}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], ) - def test_room_does_not_exist(self): + def test_room_does_not_exist(self) -> None: """ - Check that unknown rooms/server return error 404. + Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. """ - body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/!unknown:test" channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) - def test_room_is_not_valid(self): + def test_room_is_not_valid(self) -> None: """ - Check that invalid room names, return an error 400. + Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. 
""" - body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/invalidroom" channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], ) - def test_join_public_room(self): + def test_join_public_room(self) -> None: """ Test joining a local user to a public room with "JoinRules.PUBLIC" """ - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1303,10 +1900,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) - def test_join_private_room_if_not_member(self): + def test_join_private_room_if_not_member(self) -> None: """ Test joining a local user to a private room with "JoinRules.INVITE" when server admin is not member of this room. @@ -1315,19 +1912,18 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.creator, tok=self.creator_tok, is_public=False ) url = f"/_synapse/admin/v1/join/{private_room_id}" - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_join_private_room_if_member(self): + def test_join_private_room_if_member(self) -> None: """ Test joining a local user to a private room with "JoinRules.INVITE", when server admin is member of this room. @@ -1352,21 +1948,20 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. 
url = f"/_synapse/admin/v1/join/{private_room_id}" - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1376,10 +1971,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - def test_join_private_room_if_owner(self): + def test_join_private_room_if_owner(self) -> None: """ Test joining a local user to a private room with "JoinRules.INVITE", when server admin is owner of this room. @@ -1388,16 +1983,15 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.admin_user, tok=self.admin_user_tok, is_public=False ) url = f"/_synapse/admin/v1/join/{private_room_id}" - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1407,10 +2001,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - def test_context_as_non_admin(self): + def test_context_as_non_admin(self) -> None: """ Test that, without being admin, one cannot use the context admin API """ @@ -1441,12 +2035,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals( - 403, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_context_as_admin(self): + def test_context_as_admin(self) -> None: """ Test that, as admin, we can find the context of an event without having joined the room. 
""" @@ -1473,7 +2065,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEquals( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -1502,7 +2094,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1519,7 +2111,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): self.public_room_id ) - def test_public_room(self): + def test_public_room(self) -> None: """Test that getting admin in a public room works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -1532,7 +2124,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) @@ -1544,7 +2136,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def test_private_room(self): + def test_private_room(self) -> None: """Test that getting admin in a private room works and we get invited.""" room_id = self.helper.create_room_as( self.creator, @@ -1559,7 +2151,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Now we test that we can join the room (we should have received an # invite) and can ban a user. @@ -1572,7 +2164,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def test_other_user(self): + def test_other_user(self) -> None: """Test that giving admin in a public room works to a non-admin user works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -1585,7 +2177,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.second_user_id, tok=self.second_tok) @@ -1597,7 +2189,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.second_tok, ) - def test_not_enough_power(self): + def test_not_enough_power(self) -> None: """Test that we get a sensible error if there are no local room admins.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -1619,17 +2211,245 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # We expect this to fail with a 400 as there are no room admins. + # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins. 
# # (Note we assert the error message to ensure that it's not denied for # some other reason) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["error"], "No local admin user in room with power to update power levels.", ) +class BlockRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = "/_synapse/admin/v1/rooms/%s/block" + + @parameterized.expand([("PUT",), ("GET",)]) + def test_requester_is_no_admin(self, method: str) -> None: + """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.""" + + channel = self.make_request( + method, + self.url % self.room_id, + content={}, + access_token=self.other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @parameterized.expand([("PUT",), ("GET",)]) + def test_room_is_not_valid(self, method: str) -> None: + """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.""" + + channel = self.make_request( + method, + self.url % "invalidroom", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "invalidroom is not a legal room ID", + channel.json_body["error"], + ) + + def test_block_is_not_valid(self) -> None: + """If parameter `block` is not valid, return an error.""" + + # `block` is not valid + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={"block": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + # `block` is not set + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # no content is send + channel = self.make_request( + "PUT", + self.url % self.room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) + + def test_block_room(self) -> None: + """Test that block a room is successful.""" + + def _request_and_test_block_room(room_id: str) -> None: + self._is_blocked(room_id, expect=False) + channel = self.make_request( + "PUT", + self.url % room_id, + content={"block": True}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["block"]) + self._is_blocked(room_id, expect=True) + + # known internal room + _request_and_test_block_room(self.room_id) + + # unknown internal room + 
_request_and_test_block_room("!unknown:test") + + # unknown remote room + _request_and_test_block_room("!unknown:remote") + + def test_block_room_twice(self) -> None: + """Test that block a room that is already blocked is successful.""" + + self._is_blocked(self.room_id, expect=False) + for _ in range(2): + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={"block": True}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["block"]) + self._is_blocked(self.room_id, expect=True) + + def test_unblock_room(self) -> None: + """Test that unblock a room is successful.""" + + def _request_and_test_unblock_room(room_id: str) -> None: + self._block_room(room_id) + + channel = self.make_request( + "PUT", + self.url % room_id, + content={"block": False}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body["block"]) + self._is_blocked(room_id, expect=False) + + # known internal room + _request_and_test_unblock_room(self.room_id) + + # unknown internal room + _request_and_test_unblock_room("!unknown:test") + + # unknown remote room + _request_and_test_unblock_room("!unknown:remote") + + def test_unblock_room_twice(self) -> None: + """Test that unblock a room that is not blocked is successful.""" + + self._block_room(self.room_id) + for _ in range(2): + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={"block": False}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body["block"]) + self._is_blocked(self.room_id, expect=False) + + def test_get_blocked_room(self) -> None: + """Test get status of a blocked room""" + + def _request_blocked_room(room_id: str) -> None: + self._block_room(room_id) + + channel = self.make_request( + "GET", + self.url % room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["block"]) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + + # known internal room + _request_blocked_room(self.room_id) + + # unknown internal room + _request_blocked_room("!unknown:test") + + # unknown remote room + _request_blocked_room("!unknown:remote") + + def test_get_unblocked_room(self) -> None: + """Test get status of a unblocked room""" + + def _request_unblocked_room(room_id: str) -> None: + self._is_blocked(room_id, expect=False) + + channel = self.make_request( + "GET", + self.url % room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body["block"]) + self.assertNotIn("user_id", channel.json_body) + + # known internal room + _request_unblocked_room(self.room_id) + + # unknown internal room + _request_unblocked_room("!unknown:test") + + # unknown remote room + _request_unblocked_room("!unknown:remote") + + def _is_blocked(self, room_id: str, expect: bool = True) -> None: + """Assert that the room is blocked or not""" + d = self._store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _block_room(self, room_id: str) -> None: + """Block a room in database""" + self.get_success(self._store.block_room(room_id, self.other_user)) + self._is_blocked(room_id, expect=True) + 
+ PURGE_TABLES = [ "current_state_events", "event_backward_extremities", diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index fbceba3254..3c59f5f766 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from http import HTTPStatus from typing import List +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, room, sync +from synapse.server import HomeServer from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -33,7 +37,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -48,14 +52,18 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/send_server_notice" - def test_no_auth(self): + def test_no_auth(self) -> None: """Try to send a server notice without authentication.""" channel = self.make_request("POST", self.url) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """If the user is not a server admin, an error is returned.""" channel = self.make_request( "POST", @@ -63,12 +71,16 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_does_not_exist(self): - """Tests that a lookup for a user that does not exist returns a 404""" + def test_user_does_not_exist(self) -> None: + """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" channel = self.make_request( "POST", self.url, @@ -76,13 +88,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": "@unknown_person:test", "content": ""}, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ channel = self.make_request( "POST", @@ -94,13 +106,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(400, 
channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "Server notices can only be sent to local users", channel.json_body["error"] ) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """If parameters are invalid, an error is returned.""" # no content, no user @@ -110,7 +122,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) # no content @@ -121,7 +133,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # no body @@ -132,7 +144,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": ""}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'body' not in content", channel.json_body["error"]) @@ -144,11 +156,11 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": {"body": ""}}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'msgtype' not in content", channel.json_body["error"]) - def test_server_notice_disabled(self): + def test_server_notice_disabled(self) -> None: """Tests that server returns error if server notice is disabled""" channel = self.make_request( "POST", @@ -160,14 +172,14 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( "Server notices are not enabled on this server", channel.json_body["error"] ) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice(self): + def test_send_server_notice(self) -> None: """ Tests that sending two server notices is successfully, the server uses the same room and do not send messages twice. 
@@ -185,7 +197,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -216,7 +228,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has no new invites or memberships self._check_invite_and_join_status(self.other_user, 0, 1) @@ -231,7 +243,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(messages[1]["sender"], "@notices:test") @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_leave_room(self): + def test_send_server_notice_leave_room(self) -> None: """ Tests that sending a server notices is successfully. The user leaves the room and the second message appears @@ -250,7 +262,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -293,7 +305,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -315,7 +327,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertNotEqual(first_room_id, second_room_id) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_delete_room(self): + def test_send_server_notice_delete_room(self) -> None: """ Tests that the user get server notice in a new room after the first server notice room was deleted. @@ -333,7 +345,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -382,7 +394,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -405,7 +417,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> RoomsForUser: + ) -> List[RoomsForUser]: """Check invite and room membership status of a user. 
Args @@ -440,7 +452,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=token ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the messages room = channel.json_body["rooms"]["join"][room_id] diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index ece89a65ac..f6e85fdaad 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -12,13 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus +from typing import List, Optional -import json -from typing import Any, Dict, List, Optional +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import SMALL_PNG @@ -30,7 +34,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -41,30 +45,38 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/statistics/users/media" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to list users without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ channel = self.make_request( "GET", self.url, - json.dumps({}), + {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. 
""" @@ -75,8 +87,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative from channel = self.make_request( @@ -85,7 +101,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit @@ -95,7 +115,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from_ts @@ -105,7 +129,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative until_ts @@ -115,7 +143,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # until_ts smaller from_ts @@ -125,7 +157,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # empty search term @@ -135,7 +171,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -145,10 +185,14 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_limit(self): + def test_limit(self) -> None: """ Testing list of media with limit """ @@ -160,13 +204,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) 
self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["users"]) - def test_from(self): + def test_from(self) -> None: """ Testing list of media with a defined starting point (from) """ @@ -178,13 +222,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["users"]) - def test_limit_and_from(self): + def test_limit_and_from(self) -> None: """ Testing list of media with a defined starting point and limit """ @@ -196,13 +240,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["users"]), 10) self._check_fields(channel.json_body["users"]) - def test_next_token(self): + def test_next_token(self) -> None: """ Testing that `next_token` appears at the right place """ @@ -218,7 +262,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -231,7 +275,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -244,7 +288,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -257,12 +301,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_no_media(self): + def test_no_media(self) -> None: """ Tests that a normal lookup for statistics is successfully if users have no media created @@ -274,11 +318,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - 
self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["users"])) - def test_order_by(self): + def test_order_by(self) -> None: """ Testing order list with parameter `order_by` """ @@ -356,7 +400,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): "b", ) - def test_from_until_ts(self): + def test_from_until_ts(self) -> None: """ Testing filter by time with parameters `from_ts` and `until_ts` """ @@ -371,7 +415,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media @@ -381,7 +425,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) self._create_media(self.other_user_tok, 3) @@ -396,7 +440,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media until `ts2` and earlier @@ -405,10 +449,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) - def test_search_term(self): + def test_search_term(self) -> None: self._create_users_with_media(20, 1) # check without filter get all users @@ -417,7 +461,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) # filter user 1 and 10-19 by `user_id` @@ -426,7 +470,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 11) # filter on this user in `displayname` @@ -435,7 +479,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10") self.assertEqual(channel.json_body["total"], 1) @@ -445,10 
+489,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) - def _create_users_with_media(self, number_users: int, media_per_user: int): + def _create_users_with_media(self, number_users: int, media_per_user: int) -> None: """ Create a number of users with a number of media Args: @@ -460,7 +504,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): user_tok = self.login("foo_user_%s" % i, "pass") self._create_media(user_tok, media_per_user) - def _create_media(self, user_token: str, number_media: int): + def _create_media(self, user_token: str, number_media: int) -> None: """ Create a number of media for a specific user Args: @@ -471,10 +515,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): for _ in range(number_media): # Upload some media into the room self.helper.upload_media( - upload_resource, SMALL_PNG, tok=user_token, expect_code=200 + upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK ) - def _check_fields(self, content: List[Dict[str, Any]]): + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in content Args: content: List that is checked for content @@ -487,7 +531,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): def _order_test( self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None - ): + ) -> None: """Request the list of users in a certain order. Assert that order is what we expect Args: @@ -505,7 +549,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["user_id"] for row in channel.json_body["users"]] diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 25e8d6cf27..4fedd5fd08 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -17,6 +17,7 @@ import hmac import os import urllib.parse from binascii import unhexlify +from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock, patch @@ -74,7 +75,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "Shared secret registration is not enabled", channel.json_body["error"] ) @@ -106,7 +107,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds @@ -114,7 +115,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, 
msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): @@ -126,18 +127,18 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) def test_register_correct_nonce(self): @@ -152,7 +153,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update( nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, @@ -160,11 +161,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) def test_nonce_reuse(self): @@ -176,24 +177,24 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): @@ -214,7 +215,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be an empty body present channel = self.make_request("POST", self.url, {}) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) # @@ -224,28 +225,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present channel = self.make_request("POST", self.url, {"nonce": nonce()}) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", 
channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # @@ -256,28 +257,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce(), "username": "a"} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": "a", "password": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = {"nonce": nonce(), "username": "a", "password": "A" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -293,7 +294,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self): @@ -307,22 +308,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob1", "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -331,22 +332,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) 
want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob2", "displayname": None, "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -355,22 +356,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob3", "displayname": "", "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) # set displayname channel = self.make_request("GET", self.url) @@ -378,22 +379,22 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob4", "displayname": "Bob's Name", "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @override_config( @@ -425,7 +426,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update( nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, @@ -433,11 +434,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -461,7 +462,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, 
channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -473,7 +474,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=other_user_token) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self): @@ -489,7 +490,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) @@ -503,7 +504,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): expected_user_id: Optional[str], search_term: str, search_field: Optional[str] = "name", - expected_http_code: Optional[int] = 200, + expected_http_code: Optional[int] = HTTPStatus.OK, ): """Search for a user and check that the returned user's id is a match @@ -525,7 +526,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != 200: + if expected_http_code != HTTPStatus.OK: return # Check that users were returned @@ -586,7 +587,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -596,7 +597,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid guests @@ -606,7 +607,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid deactivated @@ -616,7 +617,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # unkown order_by @@ -626,7 +627,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -636,7 +637,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) def test_limit(self): @@ 
-654,7 +655,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -675,7 +676,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -696,7 +697,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["users"]), 10) @@ -719,7 +720,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -732,7 +733,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -745,7 +746,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -759,7 +760,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -862,14 +863,14 @@ class UsersListTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["name"] for row in channel.json_body["users"]] self.assertEqual(expected_user_list, returned_order) self._check_fields(channel.json_body["users"]) - def _check_fields(self, content: JsonDict): + def _check_fields(self, content: List[JsonDict]): """Checks that the expected user attributes are present in content Args: content: List that is 
checked for content @@ -936,7 +937,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -947,7 +948,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", url, access_token=self.other_user_token) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -957,12 +958,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): """ - Tests that deactivation for a user that does not exist returns a 404 + Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND """ channel = self.make_request( @@ -971,7 +972,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_erase_is_not_bool(self): @@ -986,18 +987,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that deactivation for a user that is not a local returns a 400 + Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" channel = self.make_request("POST", url, access_token=self.admin_user_tok) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only deactivate local users", channel.json_body["error"]) def test_deactivate_user_erase_true(self): @@ -1012,7 +1013,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1027,7 +1028,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1036,7 +1037,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, 
msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1057,7 +1058,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1072,7 +1073,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": False}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1081,7 +1082,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1111,7 +1112,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1126,7 +1127,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1135,7 +1136,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1169,14 +1170,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): # regardless of whether password login or SSO is allowed self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.admin_user, device_id=None, valid_until_ms=None ) ) self.other_user = self.register_user("user", "pass", displayname="User") self.other_user_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.other_user, device_id=None, valid_until_ms=None ) ) @@ -1195,7 +1196,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", 
channel.json_body["error"]) channel = self.make_request( @@ -1205,12 +1206,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ channel = self.make_request( @@ -1219,7 +1220,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -1234,7 +1235,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"admin": "not_bool"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) # deactivated not bool @@ -1244,7 +1245,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": "not_bool"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not str @@ -1254,7 +1255,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": True}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not length @@ -1264,7 +1265,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": "x" * 513}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # user_type not valid @@ -1274,7 +1275,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"user_type": "new type"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # external_ids not valid @@ -1286,7 +1287,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} }, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1295,7 +1296,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # threepids not 
valid @@ -1305,7 +1306,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1314,7 +1315,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"address": "value"}}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_get_user(self): @@ -1327,7 +1328,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) @@ -1370,7 +1371,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1433,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1461,9 +1462,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # before limit of monthly active users is reached channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) - if channel.code != 200: + if channel.code != HTTPStatus.OK: raise HttpResponseException( - channel.code, channel.result["reason"], channel.result["body"] + channel.code, channel.result["reason"], channel.json_body ) # Set monthly active users to the limit @@ -1625,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "hahaha"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) def test_set_displayname(self): @@ -1641,7 +1642,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "foobar"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1652,7 +1653,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", 
channel.json_body["displayname"]) @@ -1674,7 +1675,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1700,7 +1701,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1716,7 +1717,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1732,7 +1733,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": []}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1759,7 +1760,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1778,7 +1779,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1800,7 +1801,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # other user has this two threepids - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1819,7 +1820,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): url_first_user, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1848,7 +1849,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) 
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1880,7 +1881,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1899,7 +1900,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1918,7 +1919,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": []}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["external_ids"])) @@ -1947,7 +1948,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1973,7 +1974,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2005,7 +2006,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # must fail - self.assertEqual(409, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("External id is already in use.", channel.json_body["error"]) @@ -2016,7 +2017,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2034,7 +2035,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2065,7 +2066,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", 
channel.json_body["threepids"][0]["address"]) @@ -2080,7 +2081,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2096,7 +2097,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2123,7 +2124,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) @@ -2139,7 +2140,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "Foobar"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) @@ -2163,7 +2164,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) # Reactivate the user. channel = self.make_request( @@ -2172,7 +2173,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNotNone(channel.json_body["password_hash"]) @@ -2194,7 +2195,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. 
@@ -2204,7 +2205,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2226,7 +2227,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2236,7 +2237,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2255,7 +2256,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"admin": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2266,7 +2267,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2283,7 +2284,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": UserTypes.SUPPORT}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2294,7 +2295,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2306,7 +2307,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": None}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2317,7 +2318,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2347,7 +2348,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, 
channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual(0, channel.json_body["deactivated"]) @@ -2360,7 +2361,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "deactivated": "false"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) # Check user is not deactivated channel = self.make_request( @@ -2369,7 +2370,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -2394,7 +2395,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) self._is_erased(user_id, False) @@ -2445,7 +2446,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2460,7 +2461,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -2474,7 +2475,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2490,7 +2491,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2506,7 +2507,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2527,7 +2528,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, 
len(channel.json_body["joined_rooms"])) @@ -2574,7 +2575,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) @@ -2603,7 +2604,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2618,12 +2619,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" channel = self.make_request( @@ -2632,12 +2633,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" @@ -2647,7 +2648,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_get_pushers(self): @@ -2662,7 +2663,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) # Register the pusher @@ -2693,7 +2694,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) for p in channel.json_body["pushers"]: @@ -2732,7 +2733,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """Try to list media of an user without authentication.""" channel = self.make_request(method, self.url, {}) - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) @@ -2746,12 +2747,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) 
- self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_does_not_exist(self, method: str): - """Tests that a lookup for a user that does not exist returns a 404""" + """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" channel = self.make_request( method, @@ -2759,12 +2760,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_is_not_local(self, method: str): - """Tests that a lookup for a user that is not a local returns a 400""" + """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" channel = self.make_request( @@ -2773,7 +2774,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_limit_GET(self): @@ -2789,7 +2790,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -2808,7 +2809,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["deleted_media"]), 5) @@ -2825,7 +2826,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -2844,7 +2845,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 15) self.assertEqual(len(channel.json_body["deleted_media"]), 15) @@ -2861,7 +2862,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["media"]), 10) @@ -2880,7 
+2881,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["deleted_media"]), 10) @@ -2894,7 +2895,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -2904,7 +2905,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative limit @@ -2914,7 +2915,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -2924,7 +2925,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self): @@ -2947,7 +2948,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2960,7 +2961,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2973,7 +2974,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -2987,7 +2988,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -3004,7 +3005,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, 
msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) @@ -3019,7 +3020,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["deleted_media"])) @@ -3036,7 +3037,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["media"])) self.assertNotIn("next_token", channel.json_body) @@ -3062,7 +3063,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["deleted_media"])) self.assertCountEqual(channel.json_body["deleted_media"], media_ids) @@ -3207,7 +3208,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, user_token, filename, expect_code=200 + upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK ) # Extract media ID from the response @@ -3225,16 +3226,16 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - 200, + HTTPStatus.OK, channel.code, msg=( - f"Expected to receive a 200 on accessing media: {server_and_media_id}" + f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}" ), ) return media_id - def _check_fields(self, content: JsonDict): + def _check_fields(self, content: List[JsonDict]): """Checks that the expected user attributes are present in content Args: content: List that is checked for content @@ -3274,7 +3275,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_media_list)) returned_order = [row["media_id"] for row in channel.json_body["media"]] @@ -3310,14 +3311,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self): """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self): @@ -3326,7 +3327,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): "POST", self.url, b"{}", access_token=self.other_user_tok ) - 
self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) def test_send_event(self): """Test that sending event as a user works.""" @@ -3351,7 +3352,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -3363,21 +3364,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_user_logout_all(self): """Tests that the target user calling `/logout/all` does *not* expire @@ -3388,23 +3389,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # .. 
but the real user's tokens shouldn't channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) def test_admin_logout_all(self): """Tests that the admin user calling `/logout/all` does expire the @@ -3415,23 +3416,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3459,7 +3460,10 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Now unaccept it and check that we can't send an event self.get_success(self.store.user_set_consent_version(self.other_user, "0.0")) self.helper.send_event( - room_id, "com.example.test", tok=self.other_user_tok, expect_code=403 + room_id, + "com.example.test", + tok=self.other_user_tok, + expect_code=HTTPStatus.FORBIDDEN, ) # Login in as the user @@ -3477,7 +3481,10 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Trying to join as the other user should fail due to reaching MAU limit. self.helper.join( - room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403 + room_id, + user=self.other_user, + tok=self.other_user_tok, + expect_code=HTTPStatus.FORBIDDEN, ) # Logging in as the other user and joining a room should work, even @@ -3512,7 +3519,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. 
""" channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -3527,12 +3534,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user2_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = self.url_prefix % "@unknown_person:unknown_domain" @@ -3541,7 +3548,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) def test_get_whois_admin(self): @@ -3553,7 +3560,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3568,7 +3575,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user_token, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3592,32 +3599,35 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + @parameterized.expand(["POST", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get information of an user without authentication. """ - channel = self.make_request("POST", self.url) - self.assertEqual(401, channel.code, msg=channel.json_body) + channel = self.make_request(method, self.url) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self): + @parameterized.expand(["POST", "DELETE"]) + def test_requester_is_not_admin(self, method: str): """ If the user is not a server admin, an error is returned. 
""" other_user_token = self.login("user", "pass") - channel = self.make_request("POST", self.url, access_token=other_user_token) - self.assertEqual(403, channel.code, msg=channel.json_body) + channel = self.make_request(method, self.url, access_token=other_user_token) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_is_not_local(self): + @parameterized.expand(["POST", "DELETE"]) + def test_user_is_not_local(self, method: str): """ - Tests that shadow-banning for a user that is not a local returns a 400 + Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" - channel = self.make_request("POST", url, access_token=self.admin_user_tok) - self.assertEqual(400, channel.code, msg=channel.json_body) + channel = self.make_request(method, url, access_token=self.admin_user_tok) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) def test_success(self): """ @@ -3629,13 +3639,24 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.assertFalse(result.shadow_banned) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is shadow-banned (and the cache was cleared). result = self.get_success(self.store.get_user_by_access_token(other_user_token)) self.assertTrue(result.shadow_banned) + # Un-shadow-ban the user. + channel = self.make_request( + "DELETE", self.url, access_token=self.admin_user_tok + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + # Ensure the user is no longer shadow-banned (and the cache was cleared). 
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertFalse(result.shadow_banned) + class RateLimitTestCase(unittest.HomeserverTestCase): @@ -3663,7 +3684,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): """ channel = self.make_request(method, self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) @@ -3679,13 +3700,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) def test_user_does_not_exist(self, method: str): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" @@ -3695,7 +3716,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand( @@ -3707,7 +3728,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) def test_user_is_not_local(self, method: str, error_msg: str): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = ( "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit" @@ -3719,7 +3740,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(error_msg, channel.json_body["error"]) def test_invalid_parameter(self): @@ -3734,7 +3755,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": "string"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3745,7 +3766,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": -1}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3756,7 +3777,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": "string"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3767,7 +3788,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": -1}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + 
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self): @@ -3792,7 +3813,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3806,7 +3827,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3817,7 +3838,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3828,7 +3849,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3838,7 +3859,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3848,7 +3869,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3858,6 +3879,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 4e1c49c28b..7978626e71 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
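Besides the HTTPStatus conversion, the username-availability hunk below swaps int(channel.result["code"]) for the channel.code property and reports failures with the decoded channel.json_body instead of the raw response body. Both spellings inspect the same response; channel.code is already parsed to an int, so the assertion reads as:

    channel = self.make_request("GET", url, None, self.admin_user_tok)
    # .code is the parsed status line; .json_body is the decoded JSON payload.
    self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)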
+from http import HTTPStatus + import synapse.rest.admin from synapse.api.errors import Codes, SynapseError from synapse.rest.client import login @@ -33,30 +35,38 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): async def check_username(username): if username == "allowed": return True - raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "User ID already taken.", + errcode=Codes.USER_IN_USE, + ) handler = self.hs.get_registration_handler() handler.check_username = check_username def test_username_available(self): """ - The endpoint should return a 200 response if the username does not exist + The endpoint should return a HTTPStatus.OK response if the username does not exist """ url = "%s?username=%s" % (self.url, "allowed") channel = self.make_request("GET", url, None, self.admin_user_tok) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["available"]) def test_username_unavailable(self): """ - The endpoint should return a 200 response if the username does not exist + The endpoint should return a HTTPStatus.BAD_REQUEST response if the username is already in use """ url = "%s?username=%s" % (self.url, "disallowed") channel = self.make_request("GET", url, None, self.admin_user_tok) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") self.assertEqual(channel.json_body["error"], "User ID already taken.") diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index e2fcbdc63a..27cb856b0a 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus from typing import Optional, Union from twisted.internet.defer import succeed @@ -84,7 +85,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) channel = self.make_request( "POST", @@ -103,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( - 401, + HTTPStatus.UNAUTHORIZED, {"username": "user", "type": "m.login.password", "password": "bar"}, ) @@ -115,15 +116,17 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) # Complete the recaptcha step. - self.recaptcha(session, 200) + self.recaptcha(session, HTTPStatus.OK) # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + self.register( + HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}} + ) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. - channel = self.register(200, {"auth": {"session": session}}) + channel = self.register(HTTPStatus.OK, {"auth": {"session": session}}) # We're given a registered user. 
self.assertEqual(channel.json_body["user_id"], "@user:test") @@ -136,7 +139,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # will be used.) # Returns a 401 as per the spec channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"} + HTTPStatus.UNAUTHORIZED, + {"username": "user", "type": "m.login.password", "password": "bar"}, ) # Grab the session @@ -230,7 +234,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -241,7 +247,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -259,14 +265,16 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -292,7 +300,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [self.device_id]}) + channel = self.delete_devices( + HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]} + ) # Grab the session session = channel.json_body["session"] @@ -302,7 +312,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. self.delete_devices( - 200, + HTTPStatus.OK, { "devices": ["dev2"], "auth": { @@ -323,7 +333,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -337,7 +349,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev2", - 403, + HTTPStatus.FORBIDDEN, { "auth": { "type": "m.login.password", @@ -360,13 +372,13 @@ class UIAuthTests(unittest.HomeserverTestCase): self.login("test", self.user_pass, "dev3") # Attempt to delete a device. This works since the user just logged in. - self.delete_device(self.user_tok, "dev2", 200) + self.delete_device(self.user_tok, "dev2", HTTPStatus.OK) # Move the clock forward past the validation timeout. self.reactor.advance(6) # Deleting another devices throws the user into UI auth. - channel = self.delete_device(self.user_tok, "dev3", 401) + channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED) # Grab the session session = channel.json_body["session"] @@ -377,7 +389,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev3", - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -392,7 +404,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # due to re-using the previous session. # # Note that *no auth* information is provided, not even a session iD! 
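The UI-auth tests above and below all perform the same three-step dance: the first request fails with HTTPStatus.UNAUTHORIZED and returns a session id plus the permitted flows; the client completes one stage (password, recaptcha, or SSO); then the original request is retried with an auth dict referencing that session. Condensed from test_ui_auth (the identifier and password fields, elided by the hunk context, follow the standard m.login.password auth dict):

    # Step 1: the bare request is rejected and opens a UIA session.
    channel = self.delete_device(
        self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
    )
    session = channel.json_body["session"]

    # Steps 2 and 3: retry, completing the password stage against that session.
    self.delete_device(
        self.user_tok,
        self.device_id,
        HTTPStatus.OK,
        {
            "auth": {
                "type": "m.login.password",
                "identifier": {"type": "m.id.user", "user": self.user},
                "password": self.user_pass,
                "session": session,
            },
        },
    )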
- self.delete_device(self.user_tok, self.device_id, 200) + self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK) @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) @@ -412,7 +424,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # check that SSO is offered flows = channel.json_body["flows"] @@ -425,13 +439,13 @@ class UIAuthTests(unittest.HomeserverTestCase): ) # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # and now the delete request should succeed. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, body={"auth": {"session": session_id}}, ) @@ -444,13 +458,15 @@ class UIAuthTests(unittest.HomeserverTestCase): # now call the device deletion API: we should get the option to auth with SSO # and not password. - channel = self.delete_device(user_tok, device_id, 401) + channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) def test_does_not_offer_sso_for_password_user(self): - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) @@ -462,7 +478,9 @@ class UIAuthTests(unittest.HomeserverTestCase): login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] # we have no particular expectations of ordering here @@ -479,7 +497,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertIn({"stages": ["m.login.sso"]}, flows) @@ -495,7 +515,10 @@ class UIAuthTests(unittest.HomeserverTestCase): # ... and the delete op should now fail with a 403 self.delete_device( - self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + self.user_tok, + self.device_id, + HTTPStatus.FORBIDDEN, + body={"auth": {"session": session_id}}, ) @@ -513,25 +536,56 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) + def use_refresh_token(self, refresh_token: str) -> FakeChannel: + """ + Helper that makes a request to use a refresh token. + """ + return self.make_request( + "POST", + "/_matrix/client/v1/refresh", + {"refresh_token": refresh_token}, + ) + + def is_access_token_valid(self, access_token) -> bool: + """ + Checks whether an access token is valid, returning whether it is or not. 
+ """ + code = self.make_request( + "GET", "/_matrix/client/v3/account/whoami", access_token=access_token + ).code + + # Either 200 or 401 is what we get back; anything else is a bug. + assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} + + return code == HTTPStatus.OK + def test_login_issue_refresh_token(self): """ A login response should include a refresh_token only if asked. """ # Test login - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + } login_without_refresh = self.make_request( "POST", "/_matrix/client/r0/login", body ) - self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertEqual( + login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result + ) self.assertNotIn("refresh_token", login_without_refresh.json_body) login_with_refresh = self.make_request( "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, + "/_matrix/client/r0/login", + {"refresh_token": True, **body}, + ) + self.assertEqual( + login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result ) - self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) @@ -549,20 +603,25 @@ class RefreshAuthTests(unittest.HomeserverTestCase): }, ) self.assertEqual( - register_without_refresh.code, 200, register_without_refresh.result + register_without_refresh.code, + HTTPStatus.OK, + register_without_refresh.result, ) self.assertNotIn("refresh_token", register_without_refresh.json_body) register_with_refresh = self.make_request( "POST", - "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", + "/_matrix/client/r0/register", { "username": "test3", "password": self.user_pass, "auth": {"type": LoginType.DUMMY}, + "refresh_token": True, }, ) - self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertEqual( + register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result + ) self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) @@ -570,20 +629,25 @@ class RefreshAuthTests(unittest.HomeserverTestCase): """ A refresh token can be used to issue a new access token. 
""" - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + "refresh_token": True, + } login_response = self.make_request( "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn("access_token", refresh_response.json_body) self.assertIn("refresh_token", refresh_response.json_body) self.assertIn("expires_in_ms", refresh_response.json_body) @@ -598,31 +662,223 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response.json_body["refresh_token"], ) - @override_config({"access_token_lifetime": "1m"}) - def test_refresh_token_expiration(self): + @override_config({"refreshable_access_token_lifetime": "1m"}) + def test_refreshable_access_token_expiration(self): """ The access token should have some time as specified in the config. """ - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + "refresh_token": True, + } login_response = self.make_request( "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) self.assertApproximates( login_response.json_body["expires_in_ms"], 60 * 1000, 100 ) refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertApproximates( refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 ) + access_token = refresh_response.json_body["access_token"] + + # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) + self.reactor.advance(59.0) + # Check that our token is valid + self.assertEqual( + self.make_request( + "GET", "/_matrix/client/v3/account/whoami", access_token=access_token + ).code, + HTTPStatus.OK, + ) + + # Advance 2 more seconds (just past the time of expiry) + self.reactor.advance(2.0) + # Check that our token is invalid + self.assertEqual( + self.make_request( + "GET", "/_matrix/client/v3/account/whoami", access_token=access_token + ).code, + HTTPStatus.UNAUTHORIZED, + ) + + @override_config( + { + "refreshable_access_token_lifetime": "1m", + "nonrefreshable_access_token_lifetime": "10m", + } + ) + def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): + """ + Tests that the expiry times for refreshable and non-refreshable access + tokens can be different. 
+ """ + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + } + login_response1 = self.make_request( + "POST", + "/_matrix/client/r0/login", + {"refresh_token": True, **body}, + ) + self.assertEqual(login_response1.code, HTTPStatus.OK, login_response1.result) + self.assertApproximates( + login_response1.json_body["expires_in_ms"], 60 * 1000, 100 + ) + refreshable_access_token = login_response1.json_body["access_token"] + + login_response2 = self.make_request( + "POST", + "/_matrix/client/r0/login", + body, + ) + self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result) + nonrefreshable_access_token = login_response2.json_body["access_token"] + + # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) + self.reactor.advance(59.0) + + # Both tokens should still be valid. + self.assertTrue(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 61 s (just past 1 minute, the time of expiry) + self.reactor.advance(2.0) + + # Only the non-refreshable token is still valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 599 s (just shy of 10 minutes, the time of expiry) + self.reactor.advance(599.0 - 61.0) + + # It's still the case that only the non-refreshable token is still valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 601 s (just past 10 minutes, the time of expiry) + self.reactor.advance(2.0) + + # Now neither token is valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) + + @override_config( + {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} + ) + def test_refresh_token_expiry(self): + """ + The refresh token can be configured to have a limited lifetime. + When that lifetime has ended, the refresh token can no longer be used to + refresh the session. + """ + + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + "refresh_token": True, + } + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login", + body, + ) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) + refresh_token1 = login_response.json_body["refresh_token"] + + # Advance 119 seconds in the future (just shy of 2 minutes) + self.reactor.advance(119.0) + + # Refresh our session. The refresh token should still JUST be valid right now. + # By doing so, we get a new access token and a new refresh token. + refresh_response = self.use_refresh_token(refresh_token1) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) + self.assertIn( + "refresh_token", + refresh_response.json_body, + "No new refresh token returned after refresh.", + ) + refresh_token2 = refresh_response.json_body["refresh_token"] + + # Advance 121 seconds in the future (just a bit more than 2 minutes) + self.reactor.advance(121.0) + + # Try to refresh our session, but instead notice that the refresh token is + # not valid (it just expired). 
+ refresh_response = self.use_refresh_token(refresh_token2) + self.assertEqual( + refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result + ) + + @override_config( + { + "refreshable_access_token_lifetime": "2m", + "refresh_token_lifetime": "2m", + "session_lifetime": "3m", + } + ) + def test_ultimate_session_expiry(self): + """ + The session can be configured to have an ultimate, limited lifetime. + """ + + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + "refresh_token": True, + } + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login", + body, + ) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) + refresh_token = login_response.json_body["refresh_token"] + + # Advance shy of 2 minutes into the future + self.reactor.advance(119.0) + + # Refresh our session. The refresh token should still be valid right now. + refresh_response = self.use_refresh_token(refresh_token) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) + self.assertIn( + "refresh_token", + refresh_response.json_body, + "No new refresh token returned after refresh.", + ) + # Notice that our access token lifetime has been diminished to match the + # session lifetime. + # 3 minutes - 119 seconds = 61 seconds. + self.assertEqual(refresh_response.json_body["expires_in_ms"], 61_000) + refresh_token = refresh_response.json_body["refresh_token"] + + # Advance 61 seconds into the future. Our session should have expired + # now, because we've had our 3 minutes. + self.reactor.advance(61.0) + + # Try to issue a new, refreshed, access token. + # This should fail because the refresh token's lifetime has also been + # diminished as our session expired. + refresh_response = self.use_refresh_token(refresh_token) + self.assertEqual( + refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result + ) def test_refresh_token_invalidation(self): """Refresh tokens are invalidated after first use of the next token. 
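The invalidation test below drives the refresh endpoint at its stable path, /_matrix/client/v1/refresh, which these hunks substitute for the experimental /_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh. Each successful refresh returns a fresh access-token/refresh-token pair, and once a descendant refresh token has been used, its ancestors are retired. One rotation step, in the shape of the use_refresh_token helper added earlier in the file:

    def rotate(self, refresh_token: str):
        # Trade the current refresh token for a new (access, refresh) pair.
        channel = self.make_request(
            "POST",
            "/_matrix/client/v1/refresh",
            {"refresh_token": refresh_token},
        )
        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
        body = channel.json_body
        return body["access_token"], body["refresh_token"]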
@@ -640,42 +896,49 @@ class RefreshAuthTests(unittest.HomeserverTestCase): |-> fourth_refresh (fails) """ - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + "refresh_token": True, + } login_response = self.make_request( "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) # This first refresh should work properly first_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - first_refresh_response.code, 200, first_refresh_response.result + first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result ) # This one as well, since the token in the first one was never used second_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - second_refresh_response.code, 200, second_refresh_response.result + second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result ) # This one should not, since the token from the first refresh is not valid anymore third_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - third_refresh_response.code, 401, third_refresh_response.result + third_refresh_response.code, + HTTPStatus.UNAUTHORIZED, + third_refresh_response.result, ) # The associated access token should also be invalid @@ -684,7 +947,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/account/whoami", access_token=first_refresh_response.json_body["access_token"], ) - self.assertEqual(whoami_response.code, 401, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result + ) # But all other tokens should work (they will expire after some time) for access_token in [ @@ -694,24 +959,28 @@ class RefreshAuthTests(unittest.HomeserverTestCase): whoami_response = self.make_request( "GET", "/_matrix/client/r0/account/whoami", access_token=access_token ) - self.assertEqual(whoami_response.code, 200, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.OK, whoami_response.result + ) # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - fourth_refresh_response.code, 403, fourth_refresh_response.result + fourth_refresh_response.code, + HTTPStatus.FORBIDDEN, + fourth_refresh_response.result, ) # But refreshing from the last valid refresh token still works fifth_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": 
second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - fifth_refresh_response.code, 200, fifth_refresh_response.result + fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index b9e3602552..249808b031 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"localdb_enabled": False}}) def test_get_change_password_capabilities_localdb_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"enabled": False}}) def test_get_change_password_capabilities_password_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"experimental_features": {"msc3244_enabled": False}}) def test_get_does_not_include_msc3244_fields_when_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): def test_get_does_include_msc3244_fields_when_enabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index d2181ea907..aca03afd0e 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -11,12 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import directory, login, room +from synapse.server import HomeServer from synapse.types import RoomAlias +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -32,7 +36,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_membership_for_aliases"] = True @@ -40,7 +44,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + """Create two local users and access tokens for them. 
+ One of them creates a room.""" self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -51,39 +59,39 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "test") self.user_tok = self.login("user", "test") - def test_state_event_not_in_room(self): + def test_state_event_not_in_room(self) -> None: self.ensure_user_left_room() - self.set_alias_via_state_event(403) + self.set_alias_via_state_event(HTTPStatus.FORBIDDEN) - def test_directory_endpoint_not_in_room(self): + def test_directory_endpoint_not_in_room(self) -> None: self.ensure_user_left_room() - self.set_alias_via_directory(403) + self.set_alias_via_directory(HTTPStatus.FORBIDDEN) - def test_state_event_in_room_too_long(self): + def test_state_event_in_room_too_long(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_state_event(400, alias_length=256) + self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256) - def test_directory_in_room_too_long(self): + def test_directory_in_room_too_long(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_directory(400, alias_length=256) + self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256) @override_config({"default_room_version": 5}) - def test_state_event_user_in_v5_room(self): + def test_state_event_user_in_v5_room(self) -> None: """Test that a regular user can add alias events before room v6""" self.ensure_user_joined_room() - self.set_alias_via_state_event(200) + self.set_alias_via_state_event(HTTPStatus.OK) @override_config({"default_room_version": 6}) - def test_state_event_v6_room(self): + def test_state_event_v6_room(self) -> None: """Test that a regular user can *not* add alias events from room v6""" self.ensure_user_joined_room() - self.set_alias_via_state_event(403) + self.set_alias_via_state_event(HTTPStatus.FORBIDDEN) - def test_directory_in_room(self): + def test_directory_in_room(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_directory(200) + self.set_alias_via_directory(HTTPStatus.OK) - def test_room_creation_too_long(self): + def test_room_creation_too_long(self) -> None: url = "/_matrix/client/r0/createRoom" # We use deliberately a localpart under the length threshold so @@ -93,9 +101,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - def test_room_creation(self): + def test_room_creation(self) -> None: url = "/_matrix/client/r0/createRoom" # Check with an alias of allowed length. There should already be @@ -106,9 +114,46 @@ class DirectoryTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + def test_deleting_alias_via_directory(self) -> None: + # Add an alias for the room. We must be joined to do so. 
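The directory tests gain explicit annotations (MemoryReactor, Clock, HomeServer, and -> None returns) alongside the status-code conversion, and set_alias_via_directory now returns the alias it registered so the new deletion test can round-trip it:

    # Create the alias, then remove it again via the directory API.
    alias = self.set_alias_via_directory(HTTPStatus.OK)
    channel = self.make_request(
        "DELETE",
        f"/_matrix/client/r0/directory/room/{alias}",
        access_token=self.user_tok,
    )
    self.assertEqual(channel.code, HTTPStatus.OK, channel.result)

Deleting or fetching an alias that was never created fails with HTTPStatus.NOT_FOUND and errcode M_NOT_FOUND, which test_deleting_nonexistant_alias checks for both GET and DELETE.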
+ self.ensure_user_joined_room() + alias = self.set_alias_via_directory(HTTPStatus.OK) + + # Then try to remove the alias + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + def test_deleting_nonexistant_alias(self) -> None: + # Check that no alias exists + alias = "#potato:test" + channel = self.make_request( + "GET", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body) + + # Then try to remove the alias + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body) - def set_alias_via_state_event(self, expected_code, alias_length=5): + def set_alias_via_state_event( + self, expected_code: HTTPStatus, alias_length: int = 5 + ) -> None: url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( self.room_id, self.hs.hostname, @@ -122,8 +167,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, expected_code, channel.result) - def set_alias_via_directory(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) + def set_alias_via_directory( + self, expected_code: HTTPStatus, alias_length: int = 5 + ) -> str: + alias = self.random_alias(alias_length) + url = "/_matrix/client/r0/directory/room/%s" % alias data = {"room_id": self.room_id} request_data = json.dumps(data) @@ -131,17 +179,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase): "PUT", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, expected_code, channel.result) + return alias - def random_alias(self, length): + def random_alias(self, length: int) -> str: return RoomAlias(random_string(length), self.hs.hostname).to_string() - def ensure_user_left_room(self): + def ensure_user_left_room(self) -> None: self.ensure_membership("leave") - def ensure_user_joined_room(self): + def ensure_user_joined_room(self) -> None: self.ensure_membership("join") - def ensure_membership(self, membership): + def ensure_membership(self, membership: str) -> None: try: if membership == "leave": self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index a63f04bd41..19f5e46537 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -79,7 +79,10 @@ EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')] # (possibly experimental) login flows we expect to appear in the list after the normal # ones -ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] +ADDITIONAL_LOGIN_FLOWS = [ + {"type": "m.login.application_service"}, + {"type": "uk.half-shot.msc2778.login.application_service"}, +] class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -812,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase): jwt_secret = "secret" jwt_algorithm = "HS256" + base_config = { + "enabled": 
True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + } - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt.jwt_enabled = True - self.hs.config.jwt.jwt_secret = self.jwt_secret - self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm - return self.hs + def default_config(self): + config = super().default_config() + + # If jwt_config has been defined (eg via @override_config), don't replace it. + if config.get("jwt_config") is None: + config["jwt_config"] = self.base_config + + return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. @@ -876,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "issuer": "test-issuer", - } - } - ) + @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) def test_login_iss(self): """Test validating the issuer claim.""" # A valid issuer. @@ -916,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "audiences": ["test-audience"], - } - } - ) + @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) def test_login_aud(self): """Test validating the audience claim.""" # A valid audience. @@ -959,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Invalid audience" ) + def test_login_default_sub(self): + """Test reading user ID from the default subject claim.""" + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) + def test_login_custom_sub(self): + """Test reading user ID from a custom subject claim.""" + channel = self.jwt_login({"username": "frog"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@frog:test") + def test_login_no_token(self): params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) @@ -1021,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): ] ) - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt.jwt_enabled = True - self.hs.config.jwt.jwt_secret = self.jwt_pubkey - self.hs.config.jwt.jwt_algorithm = "RS256" - return self.hs + def default_config(self): + config = super().default_config() + config["jwt_config"] = { + "enabled": True, + "secret": self.jwt_pubkey, + "algorithm": "RS256", + } + return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. 
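# Aside: a hedged sketch of the round trip these JWT tests drive, assuming
# PyJWT >= 2 (where jwt.encode returns str) and the base_config shown above
# ("secret", HS256). The payload and endpoint mirror requests made later in
# this file; "test" is the test homeserver's name:
import jwt

token = jwt.encode({"sub": "kermit"}, "secret", algorithm="HS256")
login_body = {"type": "org.matrix.login.jwt", "token": token}
# POSTed to /_matrix/client/r0/login this should yield 200 with
# {"user_id": "@kermit:test", ...}; a bad signature, or a mismatched
# "iss"/"aud" claim when one is configured, yields 403 M_FORBIDDEN instead.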
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 78c2fb86b9..397c12c2a6 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1,4 +1,5 @@ # Copyright 2019 New Vector Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +19,7 @@ from typing import Dict, List, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client import login, register, relations, room +from synapse.rest.client import login, register, relations, room, sync from tests import unittest from tests.server import FakeChannel @@ -28,6 +29,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): servlets = [ relations.register_servlets, room.register_servlets, + sync.register_servlets, login.register_servlets, register.register_servlets, admin.register_servlets_for_client_rest_resource, @@ -46,6 +48,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id, self.user_token = self._create_user("alice") self.user2_id, self.user2_token = self._create_user("bob") @@ -91,6 +95,49 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) + def test_deny_invalid_event(self): + """Test that we deny relations on non-existent events""" + channel = self._send_relation( + RelationTypes.ANNOTATION, + EventTypes.Message, + parent_id="foo", + content={"body": "foo", "msgtype": "m.text"}, + ) + self.assertEquals(400, channel.code, channel.json_body) + + # Unless that event is referenced from another event! + self.get_success( + self.hs.get_datastore().db_pool.simple_insert( + table="event_relations", + values={ + "event_id": "bar", + "relates_to_id": "foo", + "relation_type": RelationTypes.THREAD, + }, + desc="test_deny_invalid_event", + ) + ) + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + parent_id="foo", + content={"body": "foo", "msgtype": "m.text"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + def test_deny_invalid_room(self): + """Test that we deny relations on events in other rooms""" + # Create another room and send a message in it. + room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) + res = self.helper.send(room2, body="Hi!", tok=self.user_token) + parent_id = res["event_id"] + + # Attempt to send an annotation to that event.
+ channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" + ) + self.assertEquals(400, channel.code, channel.json_body) + def test_deny_double_react(self): """Test that we deny a second annotation with the same key on the same event""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") @@ -99,6 +146,25 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(400, channel.code, channel.json_body) + def test_deny_forked_thread(self): + """It is invalid to start a thread off a thread.""" + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "foo"}, + parent_id=self.parent_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + parent_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "foo"}, + parent_id=parent_id, + ) + self.assertEquals(400, channel.code, channel.json_body) + def test_basic_paginate_relations(self): """Tests that calling the pagination API correctly returns the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -389,11 +455,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_aggregation_get_event(self): - """Test that annotations, references, and threads get correctly bundled when - getting the parent event. - """ - + def test_bundled_aggregations(self): + """Test that annotations, references, and threads get correctly bundled.""" + # Set up by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -420,43 +484,169 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] + def assert_bundle(actual): + """Assert the expected values of the bundled aggregations.""" + + # Ensure the fields are as expected. + self.assertCountEqual( + actual.keys(), + ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.THREAD, + ), + ) + + # Check the values of each field. + self.assertEquals( + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + actual[RelationTypes.ANNOTATION], + ) + + self.assertEquals( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + actual[RelationTypes.REFERENCE], + ) + + self.assertEquals( + 2, + actual[RelationTypes.THREAD].get("count"), + ) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "user_id": self.user_id, + }, + actual[RelationTypes.THREAD].get("latest_event"), + ) + + def _find_and_assert_event(events): + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+ """ + for event in events: + if event["event_id"] == self.parent_id: + break + else: + raise AssertionError(f"Event {self.parent_id} not found in chunk") + assert_bundle(event["unsigned"].get("m.relations")) + + # Request the event directly. channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["unsigned"].get("m.relations")) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + _find_and_assert_event(channel.json_body["chunk"]) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) + + # Request sync. + channel = self.make_request("GET", "/sync", access_token=self.user_token) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + _find_and_assert_event(room_timeline["events"]) + + # Note that /relations is tested separately in test_aggregation_get_event_for_thread + # since it needs different data configured. + + def test_aggregation_get_event_for_annotation(self): + """Test that annotations do not get bundled aggregations included + when directly requested. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + + # Annotate the annotation. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{annotation_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) + + def test_aggregation_get_event_for_thread(self): + """Test that threads get bundled aggregations included when directly requested.""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + thread_id = channel.json_body["event_id"] + + # Annotate the annotation. 
+ channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + ) + self.assertEquals(200, channel.code, channel.json_body) + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{thread_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) self.assertEquals( channel.json_body["unsigned"].get("m.relations"), { RelationTypes.ANNOTATION: { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] }, - RelationTypes.REFERENCE: { - "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] - }, - RelationTypes.THREAD: { - "count": 2, - "latest_event": { - "age": 100, - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "origin_server_ts": 1600, - "room_id": self.room, - "sender": self.user_id, - "type": "m.room.test", - "unsigned": {"age": 100}, - "user_id": self.user_id, - }, + }, + ) + + # It should also be included when the entire thread is requested. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + thread_message = channel.json_body["chunk"][0] + self.assertEquals( + thread_message["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] }, }, ) @@ -607,6 +797,56 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + def test_edit_edit(self): + """Test that an edit cannot be edited.""" + new_body = {"msgtype": "m.text", "body": "Initial edit"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": new_body, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + edit_event_id = channel.json_body["event_id"] + + # Edit the edit event. + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "foo", + "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"}, + }, + parent_id=edit_event_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Request the original event. + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + # The edit to the edit should be ignored. + self.assertEquals(channel.json_body["content"], new_body) + + # The relations information should not include the edit to the edit. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + def test_relations_redaction_redacts_edits(self): """Test that edits of an event are redacted when the original event is redacted. 
@@ -703,6 +943,52 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIn("chunk", channel.json_body) self.assertEquals(channel.json_body["chunk"], []) + def test_unknown_relations(self): + """Unknown relations should be accepted.""" + channel = self._send_relation("m.relation.test", "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the full + # relation event we sent above. + self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, + channel.json_body["chunk"][0], + ) + + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEquals( + channel.json_body["original_event"]["event_id"], self.parent_id + ) + + # When bundling the unknown relation is not included. + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + + # But unknown relations can be directly queried. + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + def _send_relation( self, relation_type: str, @@ -749,3 +1035,65 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token = self.login(localpart, "abc123") return user_id, access_token + + def test_background_update(self): + """Test the event_arbitrary_relations background update.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + self.assertEquals(200, channel.code, channel.json_body) + annotation_event_id_good = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") + self.assertEquals(200, channel.code, channel.json_body) + annotation_event_id_bad = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + thread_event_id = channel.json_body["event_id"] + + # Clean-up the table as if the inserts did not happen during event creation. + self.get_success( + self.store.db_pool.simple_delete_many( + table="event_relations", + column="event_id", + iterable=(annotation_event_id_bad, thread_event_id), + keyvalues={}, + desc="RelationsTestCase.test_background_update", + ) + ) + + # Only the "good" annotation should be found. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEquals( + [ev["event_id"] for ev in channel.json_body["chunk"]], + [annotation_event_id_good], + ) + + # Insert and run the background update. 
+ self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "event_arbitrary_relations", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + self.wait_for_background_updates() + + # The "good" annotation and the thread should be found, but not the "bad" + # annotation. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertCountEqual( + [ev["event_id"] for ev in channel.json_body["chunk"]], + [annotation_event_id_good, thread_event_id], + ) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index ec0979850b..1af5e5cee5 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -19,10 +19,21 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union +from typing import ( + Any, + AnyStr, + Dict, + Iterable, + Mapping, + MutableMapping, + Optional, + Tuple, + overload, +) from unittest.mock import patch import attr +from typing_extensions import Literal from twisted.web.resource import Resource from twisted.web.server import Site @@ -45,6 +56,32 @@ class RestHelper: site = attr.ib(type=Site) auth_user_id = attr.ib() + @overload + def create_room_as( + self, + room_creator: Optional[str] = ..., + is_public: Optional[bool] = ..., + room_version: Optional[str] = ..., + tok: Optional[str] = ..., + expect_code: Literal[200] = ..., + extra_content: Optional[Dict] = ..., + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., + ) -> str: + ... + + @overload + def create_room_as( + self, + room_creator: Optional[str] = ..., + is_public: Optional[bool] = ..., + room_version: Optional[str] = ..., + tok: Optional[str] = ..., + expect_code: int = ..., + extra_content: Optional[Dict] = ..., + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., + ) -> Optional[str]: + ... + def create_room_as( self, room_creator: Optional[str] = None, @@ -53,10 +90,8 @@ class RestHelper: tok: Optional[str] = None, expect_code: int = 200, extra_content: Optional[Dict] = None, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, - ) -> str: + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + ) -> Optional[str]: """ Create a room. @@ -99,6 +134,8 @@ class RestHelper: if expect_code == 200: return channel.json_body["room_id"] + else: + return None def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): self.change_membership( @@ -168,7 +205,7 @@ class RestHelper: extra_data: Optional[dict] = None, tok: Optional[str] = None, expect_code: int = 200, - expect_errcode: str = None, + expect_errcode: Optional[str] = None, ) -> None: """ Send a membership state event into a room. 
@@ -227,9 +264,7 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): if body is None: body = "body_text_here" @@ -254,9 +289,7 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -418,7 +451,7 @@ class RestHelper: path, content=image_data, access_token=tok, - custom_headers=[(b"Content-Length", str(image_length))], + custom_headers=[("Content-Length", str(image_length))], ) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( @@ -503,7 +536,7 @@ class RestHelper: went. """ - cookies = {} + cookies: Dict[str, str] = {} # if we're doing a ui auth, hit the ui auth redirect endpoint if ui_auth_session_id: @@ -625,7 +658,13 @@ class RestHelper: # hit the redirect url again with the right Host header, which should now issue # a cookie and redirect to the SSO provider. - location = channel.headers.getRawHeaders("Location")[0] + def get_location(channel: FakeChannel) -> str: + location_values = channel.headers.getRawHeaders("Location") + # Keep mypy happy by asserting that location_values is nonempty + assert location_values + return location_values[0] + + location = get_location(channel) parts = urllib.parse.urlsplit(location) channel = make_request( self.hs.get_reactor(), @@ -639,7 +678,7 @@ class RestHelper: assert channel.code == 302 channel.extract_cookies(cookies) - return channel.headers.getRawHeaders("Location")[0] + return get_location(channel) def initiate_sso_ui_auth( self, ui_auth_session_id: str, cookies: MutableMapping[str, str] diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 09504a485f..913bc530aa 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/rest/media/v1/test_filepath.py @@ -11,7 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
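# Aside: two typing devices carry the tests/rest/client/utils.py changes above.
# @overload with Literal[200] lets mypy give callers that use the default
# expect_code a plain `str`, while any other code gets `Optional[str]`; AnyStr
# is a constrained TypeVar forcing both halves of each header tuple to be the
# same type, which is why the Content-Length literal switched from bytes to
# str to match str(image_length). A self-contained sketch (names are
# illustrative, not Synapse's):
from typing import AnyStr, Iterable, Optional, Tuple, overload

from typing_extensions import Literal

@overload
def create(expect_code: Literal[200] = ...) -> str: ...
@overload
def create(expect_code: int = ...) -> Optional[str]: ...
def create(expect_code: int = 200) -> Optional[str]:
    return "!room:test" if expect_code == 200 else None

def send_headers(headers: Iterable[Tuple[AnyStr, AnyStr]]) -> None:
    pass

assert create() == "!room:test"
send_headers([("Content-Length", "42")])    # ok: both str
send_headers([(b"Content-Length", b"42")])  # ok: both bytes
# send_headers([(b"Content-Length", "42")]) would be rejected by mypy.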
-from synapse.rest.media.v1.filepath import MediaFilePaths +import inspect +import os +from typing import Iterable + +from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check from tests import unittest @@ -236,3 +240,356 @@ class MediaFilePathsTestCase(unittest.TestCase): "/media_store/url_cache_thumbnails/Ge", ], ) + + def test_server_name_validation(self): + """Test validation of server names""" + self._test_path_validation( + [ + "remote_media_filepath_rel", + "remote_media_filepath", + "remote_media_thumbnail_rel", + "remote_media_thumbnail", + "remote_media_thumbnail_rel_legacy", + "remote_media_thumbnail_dir", + ], + parameter="server_name", + valid_values=[ + "matrix.org", + "matrix.org:8448", + "matrix-federation.matrix.org", + "matrix-federation.matrix.org:8448", + "10.1.12.123", + "10.1.12.123:8448", + "[fd00:abcd::ffff]", + "[fd00:abcd::ffff]:8448", + ], + invalid_values=[ + "/matrix.org", + "matrix.org/..", + "matrix.org\x00", + "", + ".", + "..", + "/", + ], + ) + + def test_file_id_validation(self): + """Test validation of local, remote and legacy URL cache file / media IDs""" + # File / media IDs get split into three parts to form paths, consisting of the + # first two characters, next two characters and rest of the ID. + valid_file_ids = [ + "GerZNDnDZVjsOtardLuwfIBg", + # Unexpected, but produces an acceptable path: + "GerZN", # "N" becomes the last directory + ] + invalid_file_ids = [ + "/erZNDnDZVjsOtardLuwfIBg", + "Ge/ZNDnDZVjsOtardLuwfIBg", + "GerZ/DnDZVjsOtardLuwfIBg", + "GerZ/..", + "G\x00rZNDnDZVjsOtardLuwfIBg", + "Ger\x00NDnDZVjsOtardLuwfIBg", + "GerZNDnDZVjsOtardLuwfIBg\x00", + "", + "Ge", + "GerZ", + "GerZ.", + "..rZNDnDZVjsOtardLuwfIBg", + "Ge..NDnDZVjsOtardLuwfIBg", + "GerZ..", + "GerZ/", + ] + + self._test_path_validation( + [ + "local_media_filepath_rel", + "local_media_filepath", + "local_media_thumbnail_rel", + "local_media_thumbnail", + "local_media_thumbnail_dir", + # Legacy URL cache media IDs + "url_cache_filepath_rel", + "url_cache_filepath", + # `url_cache_filepath_dirs_to_delete` is tested below. + "url_cache_thumbnail_rel", + "url_cache_thumbnail", + "url_cache_thumbnail_directory_rel", + "url_cache_thumbnail_directory", + "url_cache_thumbnail_dirs_to_delete", + ], + parameter="media_id", + valid_values=valid_file_ids, + invalid_values=invalid_file_ids, + ) + + # `url_cache_filepath_dirs_to_delete` ignores what would be the last path + # component, so only the first 4 characters matter. 
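# Aside: an illustration of the three-part layout the comment above describes -
# the first two characters, the next two, and the remainder of the file/media
# ID become nested path components (the helper name is illustrative, not
# Synapse's; output shown for POSIX separators):
import os

def split_media_id(media_id: str) -> str:
    return os.path.join(media_id[0:2], media_id[2:4], media_id[4:])

assert split_media_id("GerZNDnDZVjsOtardLuwfIBg") == "Ge/rZ/NDnDZVjsOtardLuwfIBg"
# This is why IDs such as "Ge" or "GerZ" (nothing left for the final
# component), or ones containing "/", "..", or NUL, appear in the invalid lists.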
+ self._test_path_validation( + [ + "url_cache_filepath_dirs_to_delete", + ], + parameter="media_id", + valid_values=valid_file_ids, + invalid_values=[ + "/erZNDnDZVjsOtardLuwfIBg", + "Ge/ZNDnDZVjsOtardLuwfIBg", + "G\x00rZNDnDZVjsOtardLuwfIBg", + "Ger\x00NDnDZVjsOtardLuwfIBg", + "", + "Ge", + "..rZNDnDZVjsOtardLuwfIBg", + "Ge..NDnDZVjsOtardLuwfIBg", + ], + ) + + self._test_path_validation( + [ + "remote_media_filepath_rel", + "remote_media_filepath", + "remote_media_thumbnail_rel", + "remote_media_thumbnail", + "remote_media_thumbnail_rel_legacy", + "remote_media_thumbnail_dir", + ], + parameter="file_id", + valid_values=valid_file_ids, + invalid_values=invalid_file_ids, + ) + + def test_url_cache_media_id_validation(self): + """Test validation of URL cache media IDs""" + self._test_path_validation( + [ + "url_cache_filepath_rel", + "url_cache_filepath", + # `url_cache_filepath_dirs_to_delete` only cares about the date prefix + "url_cache_thumbnail_rel", + "url_cache_thumbnail", + "url_cache_thumbnail_directory_rel", + "url_cache_thumbnail_directory", + "url_cache_thumbnail_dirs_to_delete", + ], + parameter="media_id", + valid_values=[ + "2020-01-02_GerZNDnDZVjsOtar", + "2020-01-02_G", # Unexpected, but produces an acceptable path + ], + invalid_values=[ + "2020-01-02", + "2020-01-02-", + "2020-01-02-.", + "2020-01-02-..", + "2020-01-02-/", + "2020-01-02-/GerZNDnDZVjsOtar", + "2020-01-02-GerZNDnDZVjsOtar/..", + "2020-01-02-GerZNDnDZVjsOtar\x00", + ], + ) + + def test_content_type_validation(self): + """Test validation of thumbnail content types""" + self._test_path_validation( + [ + "local_media_thumbnail_rel", + "local_media_thumbnail", + "remote_media_thumbnail_rel", + "remote_media_thumbnail", + "remote_media_thumbnail_rel_legacy", + "url_cache_thumbnail_rel", + "url_cache_thumbnail", + ], + parameter="content_type", + valid_values=[ + "image/jpeg", + ], + invalid_values=[ + "", # ValueError: not enough values to unpack + "image/jpeg/abc", # ValueError: too many values to unpack + "image/jpeg\x00", + ], + ) + + def test_thumbnail_method_validation(self): + """Test validation of thumbnail methods""" + self._test_path_validation( + [ + "local_media_thumbnail_rel", + "local_media_thumbnail", + "remote_media_thumbnail_rel", + "remote_media_thumbnail", + "url_cache_thumbnail_rel", + "url_cache_thumbnail", + ], + parameter="method", + valid_values=[ + "crop", + "scale", + ], + invalid_values=[ + "/scale", + "scale/..", + "scale\x00", + "/", + ], + ) + + def _test_path_validation( + self, + methods: Iterable[str], + parameter: str, + valid_values: Iterable[str], + invalid_values: Iterable[str], + ): + """Test that the specified methods validate the named parameter as expected + + Args: + methods: The names of `MediaFilePaths` methods to test + parameter: The name of the parameter to test + valid_values: A list of parameter values that are expected to be accepted + invalid_values: A list of parameter values that are expected to be rejected + + Raises: + AssertionError: If a value was accepted when it should have failed + validation. + ValueError: If a value failed validation when it should have been accepted. 
+ """ + for method in methods: + get_path = getattr(self.filepaths, method) + + parameters = inspect.signature(get_path).parameters + kwargs = { + "server_name": "matrix.org", + "media_id": "GerZNDnDZVjsOtardLuwfIBg", + "file_id": "GerZNDnDZVjsOtardLuwfIBg", + "width": 800, + "height": 600, + "content_type": "image/jpeg", + "method": "scale", + } + + if get_path.__name__.startswith("url_"): + kwargs["media_id"] = "2020-01-02_GerZNDnDZVjsOtar" + + kwargs = {k: v for k, v in kwargs.items() if k in parameters} + kwargs.pop(parameter) + + for value in valid_values: + kwargs[parameter] = value + get_path(**kwargs) + # No exception should be raised + + for value in invalid_values: + with self.assertRaises(ValueError): + kwargs[parameter] = value + path_or_list = get_path(**kwargs) + self.fail( + f"{value!r} unexpectedly passed validation: " + f"{method} returned {path_or_list!r}" + ) + + +class MediaFilePathsJailTestCase(unittest.TestCase): + def _check_relative_path(self, filepaths: MediaFilePaths, path: str) -> None: + """Passes a relative path through the jail check. + + Args: + filepaths: The `MediaFilePaths` instance. + path: A path relative to the media store directory. + + Raises: + ValueError: If the jail check fails. + """ + + @_wrap_with_jail_check(relative=True) + def _make_relative_path(self: MediaFilePaths, path: str) -> str: + return path + + _make_relative_path(filepaths, path) + + def _check_absolute_path(self, filepaths: MediaFilePaths, path: str) -> None: + """Passes an absolute path through the jail check. + + Args: + filepaths: The `MediaFilePaths` instance. + path: A path relative to the media store directory. + + Raises: + ValueError: If the jail check fails. + """ + + @_wrap_with_jail_check(relative=False) + def _make_absolute_path(self: MediaFilePaths, path: str) -> str: + return os.path.join(self.base_path, path) + + _make_absolute_path(filepaths, path) + + def test_traversal_inside(self) -> None: + """Test the jail check for paths that stay within the media directory.""" + # Despite the `../`s, these paths still lie within the media directory and it's + # expected for the jail check to allow them through. + # These paths ought to trip the other checks in place and should never be + # returned. + filepaths = MediaFilePaths("/media_store") + path = "url_cache/2020-01-02/../../GerZNDnDZVjsOtar" + self._check_relative_path(filepaths, path) + self._check_absolute_path(filepaths, path) + + def test_traversal_outside(self) -> None: + """Test that the jail check fails for paths that escape the media directory.""" + filepaths = MediaFilePaths("/media_store") + path = "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar" + with self.assertRaises(ValueError): + self._check_relative_path(filepaths, path) + with self.assertRaises(ValueError): + self._check_absolute_path(filepaths, path) + + def test_traversal_reentry(self) -> None: + """Test the jail check for paths that exit and re-enter the media directory.""" + # These paths lie outside the media directory if it is a symlink, and inside + # otherwise. Ideally the check should fail, but this proves difficult. + # This test documents the behaviour for this edge case. + # These paths ought to trip the other checks in place and should never be + # returned. 
+ filepaths = MediaFilePaths("/media_store") + path = "url_cache/2020-01-02/../../../media_store/GerZNDnDZVjsOtar" + self._check_relative_path(filepaths, path) + self._check_absolute_path(filepaths, path) + + def test_symlink(self) -> None: + """Test that a symlink does not cause the jail check to fail.""" + media_store_path = self.mktemp() + + # symlink the media store directory + os.symlink("/mnt/synapse/media_store", media_store_path) + + # Test that relative and absolute paths don't trip the check + # NB: `media_store_path` is a relative path + filepaths = MediaFilePaths(media_store_path) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + + filepaths = MediaFilePaths(os.path.abspath(media_store_path)) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + + def test_symlink_subdirectory(self) -> None: + """Test that a symlinked subdirectory does not cause the jail check to fail.""" + media_store_path = self.mktemp() + os.mkdir(media_store_path) + + # symlink `url_cache/` + os.symlink( + "/mnt/synapse/media_store_url_cache", + os.path.join(media_store_path, "url_cache"), + ) + + # Test that relative and absolute paths don't trip the check + # NB: `media_store_path` is a relative path + filepaths = MediaFilePaths(media_store_path) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + + filepaths = MediaFilePaths(os.path.abspath(media_store_path)) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") diff --git a/tests/server.py b/tests/server.py index 103351b487..40cf5b12c3 100644 --- a/tests/server.py +++ b/tests/server.py @@ -16,7 +16,17 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union +from typing import ( + AnyStr, + Callable, + Dict, + Iterable, + MutableMapping, + Optional, + Tuple, + Type, + Union, +) import attr from typing_extensions import Deque @@ -217,14 +227,12 @@ def make_request( path: Union[bytes, str], content: Union[bytes, str, JsonDict] = b"", access_token: Optional[str] = None, - request: Request = SynapseRequest, + request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 4b67bd15b7..36c933b9e9 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -66,7 +66,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): self.store.db_pool.simple_insert( "background_updates", { - "update_name": "remove_deleted_devices_from_device_inbox", + "update_name": "remove_dead_devices_from_device_inbox", "progress_json": "{}", }, ) @@ -140,7 +140,7 @@ class 
DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): self.store.db_pool.simple_insert( "background_updates", { - "update_name": "remove_hidden_devices_from_device_inbox", + "update_name": "remove_dead_devices_from_device_inbox", "progress_json": "{}", }, ) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index a649e8c618..5ae491ff5a 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -12,11 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from contextlib import contextmanager +from typing import Generator +from twisted.enterprise.adbapi import ConnectionPool +from twisted.internet.defer import ensureDeferred +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.room_versions import EventFormatVersions, RoomVersions from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room -from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.server import HomeServer +from synapse.storage.databases.main.events_worker import ( + EVENT_QUEUE_THREADS, + EventsWorkerStore, +) +from synapse.storage.types import Connection +from synapse.util import Clock from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -144,3 +157,127 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # We should have fetched the event from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + +class DatabaseOutageTestCase(unittest.HomeserverTestCase): + """Test event fetching during a database outage.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.store: EventsWorkerStore = hs.get_datastore() + + self.room_id = f"!room:{hs.hostname}" + self.event_ids = [f"event{i}" for i in range(20)] + + self._populate_events() + + def _populate_events(self) -> None: + """Ensure that there are test events in the database. + + When testing with the in-memory SQLite database, all the events are lost during + the simulated outage. + + To ensure consistency between `room_id`s and `event_id`s before and after the + outage, rows are built and inserted manually. + + Upserts are used to handle the non-SQLite case where events are not lost. + """ + self.get_success( + self.store.db_pool.simple_upsert( + "rooms", + {"room_id": self.room_id}, + {"room_version": RoomVersions.V4.identifier}, + ) + ) + + self.event_ids = [f"event{i}" for i in range(20)] + for idx, event_id in enumerate(self.event_ids): + self.get_success( + self.store.db_pool.simple_upsert( + "events", + {"event_id": event_id}, + { + "event_id": event_id, + "room_id": self.room_id, + "topological_ordering": idx, + "stream_ordering": idx, + "type": "test", + "processed": True, + "outlier": False, + }, + ) + ) + self.get_success( + self.store.db_pool.simple_upsert( + "event_json", + {"event_id": event_id}, + { + "room_id": self.room_id, + "json": json.dumps({"type": "test", "room_id": self.room_id}), + "internal_metadata": "{}", + "format_version": EventFormatVersions.V3, + }, + ) + ) + + @contextmanager + def _outage(self) -> Generator[None, None, None]: + """Simulate a database outage. + + Returns: + A context manager. While the context is active, any attempts to connect to + the database will fail. 
+ """ + connection_pool = self.store.db_pool._db_pool + + # Close all connections and shut down the database `ThreadPool`. + connection_pool.close() + + # Restart the database `ThreadPool`. + connection_pool.start() + + original_connection_factory = connection_pool.connectionFactory + + def connection_factory(_pool: ConnectionPool) -> Connection: + raise Exception("Could not connect to the database.") + + connection_pool.connectionFactory = connection_factory # type: ignore[assignment] + try: + yield + finally: + connection_pool.connectionFactory = original_connection_factory + + # If the in-memory SQLite database is being used, all the events are gone. + # Restore the test data. + self._populate_events() + + def test_failure(self) -> None: + """Test that event fetches do not get stuck during a database outage.""" + with self._outage(): + failure = self.get_failure( + self.store.get_event(self.event_ids[0]), Exception + ) + self.assertEqual(str(failure.value), "Could not connect to the database.") + + def test_recovery(self) -> None: + """Test that event fetchers recover after a database outage.""" + with self._outage(): + # Kick off a bunch of event fetches but do not pump the reactor + event_deferreds = [] + for event_id in self.event_ids: + event_deferreds.append(ensureDeferred(self.store.get_event(event_id))) + + # We should have maxed out on event fetcher threads + self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS) + + # All the event fetchers will fail + self.pump() + self.assertEqual(self.store._event_fetch_ongoing, 0) + + for event_deferred in event_deferreds: + failure = self.get_failure(event_deferred, Exception) + self.assertEqual( + str(failure.value), "Could not connect to the database." + ) + + # This next event fetch should succeed + self.get_success(self.store.get_event(self.event_ids[0])) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index f26d5acf9c..329490caad 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -14,35 +14,37 @@ import json import os import tempfile +from typing import List, Optional, cast from unittest.mock import Mock import yaml from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError +from synapse.events import EventBase +from synapse.server import HomeServer from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable -from tests.utils import setup_test_homeserver -class ApplicationServiceStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks +class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): def setUp(self): - self.as_yaml_files = [] - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) + super(ApplicationServiceStoreTestCase, self).setUp() + + self.as_yaml_files: List[str] = [] - hs.config.appservice.app_service_config_files = self.as_yaml_files - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = self.as_yaml_files + self.hs.config.caches.event_cache_size = 1 self.as_token = "token1" self.as_url = "some_url" @@ -53,12 +55,14 @@ class 
ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] self.store = ApplicationServiceStore( - database, make_conn(database._database_config, database.engine, "test"), hs + database, + make_conn(database._database_config, database.engine, "test"), + self.hs, ) - def tearDown(self): + def tearDown(self) -> None: # TODO: suboptimal that we need to create files for tests! for f in self.as_yaml_files: try: @@ -66,7 +70,9 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): except Exception: pass - def _add_appservice(self, as_token, id, url, hs_token, sender): + super(ApplicationServiceStoreTestCase, self).tearDown() + + def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: as_yaml = { "url": url, "as_token": as_token, @@ -80,12 +86,13 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def test_retrieve_unknown_service_token(self): + def test_retrieve_unknown_service_token(self) -> None: service = self.store.get_app_service_by_token("invalid_token") self.assertEquals(service, None) - def test_retrieval_of_service(self): + def test_retrieval_of_service(self) -> None: stored_service = self.store.get_app_service_by_token(self.as_token) + assert stored_service is not None self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.id, self.as_id) self.assertEquals(stored_service.url, self.as_url) @@ -93,22 +100,18 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) - def test_retrieval_of_all_services(self): + def test_retrieval_of_all_services(self) -> None: services = self.store.get_app_services() self.assertEquals(len(services), 3) -class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.as_yaml_files = [] - - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) +class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): + def setUp(self) -> None: + super(ApplicationServiceTransactionStoreTestCase, self).setUp() + self.as_yaml_files: List[str] = [] - hs.config.appservice.app_service_config_files = self.as_yaml_files - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = self.as_yaml_files + self.hs.config.caches.event_cache_size = 1 self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, @@ -117,21 +120,21 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): {"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"}, ] for s in self.as_list: - yield self._add_service(s["url"], s["token"], s["id"]) + self._add_service(s["url"], s["token"], s["id"]) self.as_yaml_files = [] # We assume there is only one database in these tests - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] self.db_pool = database._db_pool self.engine = database.engine - db_config = hs.config.database.get_single_database() + db_config = self.hs.config.database.get_single_database() 
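# Aside: the `.value` additions threaded through this file follow from
# ApplicationServiceState now being a real Enum - raw SQL parameters must be
# the underlying primitive, not the member object. A minimal sketch (the
# string values here are illustrative, not taken from this diff):
from enum import Enum

class ServiceState(Enum):
    UP = "up"
    DOWN = "down"

params = ("as_id_1", ServiceState.DOWN.value, None)  # binds the plain string
assert params[1] == "down"
# Binding ServiceState.DOWN itself would hand the driver an Enum object,
# which sqlite3/psycopg2 cannot adapt without extra registration.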
self.store = TestTransactionStore( - database, make_conn(db_config, self.engine, "test"), hs + database, make_conn(db_config, self.engine, "test"), self.hs ) - def _add_service(self, url, as_token, id): + def _add_service(self, url, as_token, id) -> None: as_yaml = { "url": url, "as_token": as_token, @@ -145,13 +148,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def _set_state(self, id, state, txn=None): + def _set_state( + self, id: str, state: ApplicationServiceState, txn: Optional[int] = None + ): return self.db_pool.runOperation( self.engine.convert_param_style( "INSERT INTO application_services_state(as_id, state, last_txn) " "VALUES(?,?,?)" ), - (id, state, txn), + (id, state.value, txn), ) def _insert_txn(self, as_id, txn_id, events): @@ -169,234 +174,277 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): "INSERT INTO application_services_state(as_id, last_txn, state) " "VALUES(?,?,?)" ), - (as_id, txn_id, ApplicationServiceState.UP), + (as_id, txn_id, ApplicationServiceState.UP.value), ) - @defer.inlineCallbacks - def test_get_appservice_state_none(self): + def test_get_appservice_state_none( + self, + ) -> None: service = Mock(id="999") - state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) + state = self.get_success(self.store.get_appservice_state(service)) self.assertEquals(None, state) - @defer.inlineCallbacks - def test_get_appservice_state_up(self): - yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) + def test_get_appservice_state_up( + self, + ) -> None: + self.get_success( + self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) + ) service = Mock(id=self.as_list[0]["id"]) - state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) + state = self.get_success( + defer.ensureDeferred(self.store.get_appservice_state(service)) + ) self.assertEquals(ApplicationServiceState.UP, state) - @defer.inlineCallbacks - def test_get_appservice_state_down(self): - yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) - yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) - yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) + def test_get_appservice_state_down( + self, + ) -> None: + self.get_success( + self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) + ) + self.get_success( + self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) + ) + self.get_success( + self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) + ) service = Mock(id=self.as_list[1]["id"]) - state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) + state = self.get_success(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) - @defer.inlineCallbacks - def test_get_appservices_by_state_none(self): - services = yield defer.ensureDeferred( + def test_get_appservices_by_state_none( + self, + ) -> None: + services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) - @defer.inlineCallbacks - def test_set_appservices_state_down(self): + def test_set_appservices_state_down( + self, + ) -> None: service = Mock(id=self.as_list[1]["id"]) - yield defer.ensureDeferred( + self.get_success( self.store.set_appservice_state(service, ApplicationServiceState.DOWN) ) - rows = yield self.db_pool.runQuery( 
- self.engine.convert_param_style( - "SELECT as_id FROM application_services_state WHERE state=?" - ), - (ApplicationServiceState.DOWN,), + rows = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), + (ApplicationServiceState.DOWN.value,), + ) ) self.assertEquals(service.id, rows[0][0]) - @defer.inlineCallbacks - def test_set_appservices_state_multiple_up(self): + def test_set_appservices_state_multiple_up( + self, + ) -> None: service = Mock(id=self.as_list[1]["id"]) - yield defer.ensureDeferred( + self.get_success( self.store.set_appservice_state(service, ApplicationServiceState.UP) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_appservice_state(service, ApplicationServiceState.DOWN) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_appservice_state(service, ApplicationServiceState.UP) ) - rows = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT as_id FROM application_services_state WHERE state=?" - ), - (ApplicationServiceState.UP,), + rows = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), + (ApplicationServiceState.UP.value,), + ) ) self.assertEquals(service.id, rows[0][0]) - @defer.inlineCallbacks - def test_create_appservice_txn_first(self): + def test_create_appservice_txn_first( + self, + ) -> None: service = Mock(id=self.as_list[0]["id"]) - events = [Mock(event_id="e1"), Mock(event_id="e2")] - txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events, []) + events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) + txn = self.get_success( + defer.ensureDeferred(self.store.create_appservice_txn(service, events, [])) ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - @defer.inlineCallbacks - def test_create_appservice_txn_older_last_txn(self): + def test_create_appservice_txn_older_last_txn( + self, + ) -> None: service = Mock(id=self.as_list[0]["id"]) - events = [Mock(event_id="e1"), Mock(event_id="e2")] - yield self._set_last_txn(service.id, 9643) # AS is falling behind - yield self._insert_txn(service.id, 9644, events) - yield self._insert_txn(service.id, 9645, events) - txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events, []) - ) + events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) + self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind + self.get_success(self._insert_txn(service.id, 9644, events)) + self.get_success(self._insert_txn(service.id, 9645, events)) + txn = self.get_success(self.store.create_appservice_txn(service, events, [])) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - @defer.inlineCallbacks - def test_create_appservice_txn_up_to_date_last_txn(self): + def test_create_appservice_txn_up_to_date_last_txn( + self, + ) -> None: service = Mock(id=self.as_list[0]["id"]) - events = [Mock(event_id="e1"), Mock(event_id="e2")] - yield self._set_last_txn(service.id, 9643) - txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events, []) - ) + events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) + self.get_success(self._set_last_txn(service.id, 9643)) + txn = self.get_success(self.store.create_appservice_txn(service, events, [])) 
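# Aside: the mechanical migration in this file turns @defer.inlineCallbacks
# generators that `yield`ed Deferreds into plain methods that drive each
# Deferred to completion via self.get_success(...) (a Synapse
# HomeserverTestCase helper). Twisted's own trial offers the same synchronous
# shape; a self-contained sketch:
from twisted.internet import defer
from twisted.trial.unittest import SynchronousTestCase

class MigrationSketch(SynchronousTestCase):
    def test_sync_style(self) -> None:
        d = defer.succeed(42)
        # rough analogue of Synapse's self.get_success(d):
        self.assertEqual(self.successResultOf(d), 42)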
self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - @defer.inlineCallbacks - def test_create_appservice_txn_up_fuzzing(self): + def test_create_appservice_txn_up_fuzzing( + self, + ) -> None: service = Mock(id=self.as_list[0]["id"]) - events = [Mock(event_id="e1"), Mock(event_id="e2")] - yield self._set_last_txn(service.id, 9643) + events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) + self.get_success(self._set_last_txn(service.id, 9643)) # dump in rows with higher IDs to make sure the queries aren't wrong. - yield self._set_last_txn(self.as_list[1]["id"], 119643) - yield self._set_last_txn(self.as_list[2]["id"], 9) - yield self._set_last_txn(self.as_list[3]["id"], 9643) - yield self._insert_txn(self.as_list[1]["id"], 119644, events) - yield self._insert_txn(self.as_list[1]["id"], 119645, events) - yield self._insert_txn(self.as_list[1]["id"], 119646, events) - yield self._insert_txn(self.as_list[2]["id"], 10, events) - yield self._insert_txn(self.as_list[3]["id"], 9643, events) - - txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events, []) - ) + self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643)) + self.get_success(self._set_last_txn(self.as_list[2]["id"], 9)) + self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643)) + self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events)) + self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events)) + self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events)) + self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events)) + self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) + + txn = self.get_success(self.store.create_appservice_txn(service, events, [])) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - @defer.inlineCallbacks - def test_complete_appservice_txn_first_txn(self): + def test_complete_appservice_txn_first_txn( + self, + ) -> None: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 1 - yield self._insert_txn(service.id, txn_id, events) - yield defer.ensureDeferred( + self.get_success(self._insert_txn(service.id, txn_id, events)) + self.get_success( self.store.complete_appservice_txn(txn_id=txn_id, service=service) ) - res = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT last_txn FROM application_services_state WHERE as_id=?" - ), - (service.id,), + res = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT last_txn FROM application_services_state WHERE as_id=?" + ), + (service.id,), + ) ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) - res = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT * FROM application_services_txns WHERE txn_id=?" - ), - (txn_id,), + res = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" 
+ ), + (txn_id,), + ) ) self.assertEquals(0, len(res)) - @defer.inlineCallbacks - def test_complete_appservice_txn_existing_in_state_table(self): + def test_complete_appservice_txn_existing_in_state_table( + self, + ) -> None: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 5 - yield self._set_last_txn(service.id, 4) - yield self._insert_txn(service.id, txn_id, events) - yield defer.ensureDeferred( + self.get_success(self._set_last_txn(service.id, 4)) + self.get_success(self._insert_txn(service.id, txn_id, events)) + self.get_success( self.store.complete_appservice_txn(txn_id=txn_id, service=service) ) - res = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT last_txn, state FROM application_services_state WHERE as_id=?" - ), - (service.id,), + res = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT last_txn, state FROM application_services_state WHERE as_id=?" + ), + (service.id,), + ) ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) - self.assertEquals(ApplicationServiceState.UP, res[0][1]) - - res = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT * FROM application_services_txns WHERE txn_id=?" - ), - (txn_id,), + self.assertEquals(ApplicationServiceState.UP.value, res[0][1]) + + res = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), + ) ) self.assertEquals(0, len(res)) - @defer.inlineCallbacks - def test_get_oldest_unsent_txn_none(self): + def test_get_oldest_unsent_txn_none( + self, + ) -> None: service = Mock(id=self.as_list[0]["id"]) - txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) + txn = self.get_success(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) - @defer.inlineCallbacks - def test_get_oldest_unsent_txn(self): + def test_get_oldest_unsent_txn(self) -> None: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) + # (ignore needed because Mypy won't allow us to assign to a method otherwise) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] - yield self._insert_txn(self.as_list[1]["id"], 9, other_events) - yield self._insert_txn(service.id, 10, events) - yield self._insert_txn(service.id, 11, other_events) - yield self._insert_txn(service.id, 12, other_events) + self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) + self.get_success(self._insert_txn(service.id, 10, events)) + self.get_success(self._insert_txn(service.id, 11, other_events)) + self.get_success(self._insert_txn(service.id, 12, other_events)) - txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) + txn = self.get_success(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) - @defer.inlineCallbacks - def test_get_appservices_by_state_single(self): - yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) - yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) + def test_get_appservices_by_state_single( + self, + ) -> None: + self.get_success( 
+ self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + ) + self.get_success( + self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) + ) - services = yield defer.ensureDeferred( + services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) - @defer.inlineCallbacks - def test_get_appservices_by_state_multiple(self): - yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) - yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) - yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) + def test_get_appservices_by_state_multiple( + self, + ) -> None: + self.get_success( + self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + ) + self.get_success( + self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) + ) + self.get_success( + self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) + ) + self.get_success( + self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) + ) - services = yield defer.ensureDeferred( + services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) @@ -407,16 +455,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver() - return hs - - def prepare(self, hs, reactor, clock): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.service = Mock(id="foo") self.store = self.hs.get_datastore() - self.get_success(self.store.set_appservice_state(self.service, "up")) + self.get_success( + self.store.set_appservice_state(self.service, ApplicationServiceState.UP) + ) - def test_get_type_stream_id_for_appservice_no_value(self): + def test_get_type_stream_id_for_appservice_no_value(self) -> None: value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) @@ -427,13 +475,13 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ) self.assertEquals(value, 0) - def test_get_type_stream_id_for_appservice_invalid_type(self): + def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( self.store.get_type_stream_id_for_appservice(self.service, "foobar"), ValueError, ) - def test_set_type_stream_id_for_appservice(self): + def test_set_type_stream_id_for_appservice(self) -> None: read_receipt_value = 1024 self.get_success( self.store.set_type_stream_id_for_appservice( @@ -455,7 +503,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ) self.assertEqual(result, read_receipt_value) - def test_set_type_stream_id_for_appservice_invalid_type(self): + def test_set_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024), ValueError, @@ -464,12 +512,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, 
database: DatabasePool, db_conn, hs) -> None: super().__init__(database, db_conn, hs) -class ApplicationServiceStoreConfigTestCase(unittest.TestCase): - def _write_config(self, suffix, **kwargs): +class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): + def _write_config(self, suffix, **kwargs) -> str: vals = { "id": "id" + suffix, "url": "url" + suffix, @@ -485,41 +533,33 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f.write(yaml.dump(vals)) return path - @defer.inlineCallbacks - def test_unique_works(self): + def test_unique_works(self) -> None: f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) - - hs.config.appservice.app_service_config_files = [f1, f2] - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = [f1, f2] + self.hs.config.caches.event_cache_size = 1 - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine, "test"), hs + database, + make_conn(database._database_config, database.engine, "test"), + self.hs, ) - @defer.inlineCallbacks - def test_duplicate_ids(self): + def test_duplicate_ids(self) -> None: f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) - - hs.config.appservice.app_service_config_files = [f1, f2] - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = [f1, f2] + self.hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] ApplicationServiceStore( database, make_conn(database._database_config, database.engine, "test"), - hs, + self.hs, ) e = cm.exception @@ -527,24 +567,19 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): self.assertIn(f2, str(e)) self.assertIn("id", str(e)) - @defer.inlineCallbacks - def test_duplicate_as_tokens(self): + def test_duplicate_as_tokens(self) -> None: f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) - - hs.config.appservice.app_service_config_files = [f1, f2] - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = [f1, f2] + self.hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] ApplicationServiceStore( database, make_conn(database._database_config, database.engine, "test"), - hs, + self.hs, ) e = cm.exception diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 0da42b5ac5..d77c001506 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,8 +1,26 @@ -from unittest.mock import Mock +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Use backported mock for AsyncMock support on Python 3.6. +from mock import Mock + +from twisted.internet.defer import Deferred, ensureDeferred from synapse.storage.background_updates import BackgroundUpdater from tests import unittest +from tests.test_utils import make_awaitable class BackgroundUpdateTestCase(unittest.HomeserverTestCase): @@ -19,11 +37,11 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ) def test_do_background_update(self): - # the time we claim each update takes - duration_ms = 42 + # the time we claim it takes to update one item when running the update + duration_ms = 10 # the target runtime for each bg update - target_background_update_duration_ms = 50000 + target_background_update_duration_ms = 100 store = self.hs.get_datastore() self.get_success( @@ -48,16 +66,14 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.side_effect = update self.update_handler.reset_mock() res = self.get_success( - self.updates.do_next_background_update( - target_background_update_duration_ms - ), - by=0.1, + self.updates.do_next_background_update(False), + by=0.01, ) self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE ) # second step: complete the update @@ -74,16 +90,93 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.side_effect = update self.update_handler.reset_mock() - result = self.get_success( - self.updates.do_next_background_update(target_background_update_duration_ms) - ) + result = self.get_success(self.updates.do_next_background_update(False)) self.assertFalse(result) self.update_handler.assert_called_once() # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = self.get_success( - self.updates.do_next_background_update(target_background_update_duration_ms) - ) + result = self.get_success(self.updates.do_next_background_update(False)) self.assertTrue(result) self.assertFalse(self.update_handler.called) + + +class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates + # the base test class should have run the real bg updates for us + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) + + self.update_deferred = Deferred() + self.update_handler = Mock(return_value=self.update_deferred) + self.updates.register_background_update_handler( + "test_update", self.update_handler + ) + + # Mock out the AsyncContextManager + self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"]) + self._update_ctx_manager.__aenter__ = Mock( + return_value=make_awaitable(None), + ) + self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None)) + + # Mock out the `update_handler` callback + self._on_update = Mock(return_value=self._update_ctx_manager) + + # Define a 
default batch size value that's not the same as the internal default + # value (100). + self._default_batch_size = 500 + + # Register the callbacks with more mocks + self.hs.get_module_api().register_background_update_controller_callbacks( + on_update=self._on_update, + min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)), + default_batch_size=Mock( + return_value=make_awaitable(self._default_batch_size), + ), + ) + + def test_controller(self): + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": "{}"}, + ) + ) + + # Set the return value for the context manager. + enter_defer = Deferred() + self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer) + + # Start the background update. + do_update_d = ensureDeferred(self.updates.do_next_background_update(True)) + + self.pump() + + # `run_update` should have been called, but the update handler won't be + # called until the `enter_defer` (returned by `__aenter__`) is resolved. + self._on_update.assert_called_once_with( + "test_update", + "master", + False, + ) + self.assertFalse(do_update_d.called) + self.assertFalse(self.update_deferred.called) + + # Resolving the `enter_defer` should call the update handler, which then + # blocks. + enter_defer.callback(100) + self.pump() + self.update_handler.assert_called_once_with({}, self._default_batch_size) + self.assertFalse(self.update_deferred.called) + self._update_ctx_manager.__aexit__.assert_not_called() + + # Resolving the update handler deferred should cause the + # `do_next_background_update` to finish and return + self.update_deferred.callback(100) + self.pump() + self._update_ctx_manager.__aexit__.assert_called() + self.get_success(do_update_d) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index b31c5eb5ec..7b7f6c349e 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -664,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ): iterations += 1 self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(False), by=0.1 ) # Ensure that we did actually take multiple iterations to process the @@ -723,7 +723,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ): iterations += 1 self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(False), by=0.1 ) # Ensure that we did actually take multiple iterations to process the diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index d2b7b89952..f8d11bac4e 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -13,42 +13,35 @@ # limitations under the License. 
-from twisted.internet import defer - from synapse.types import UserID from tests import unittest -from tests.utils import setup_test_homeserver -class DataStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) +class DataStoreTestCase(unittest.HomeserverTestCase): + def setUp(self) -> None: + super(DataStoreTestCase, self).setUp() - self.store = hs.get_datastore() + self.store = self.hs.get_datastore() self.user = UserID.from_string("@abcde:test") self.displayname = "Frank" - @defer.inlineCallbacks - def test_get_users_paginate(self): - yield defer.ensureDeferred( - self.store.register_user(self.user.to_string(), "pass") - ) - yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) - yield defer.ensureDeferred( + def test_get_users_paginate(self) -> None: + self.get_success(self.store.register_user(self.user.to_string(), "pass")) + self.get_success(self.store.create_profile(self.user.localpart)) + self.get_success( self.store.set_profile_displayname(self.user.localpart, self.displayname) ) - users, total = yield defer.ensureDeferred( + users, total = self.get_success( self.store.get_users_paginate(0, 10, name="bc", guests=False) ) self.assertEquals(1, total) self.assertEquals(self.displayname, users.pop()["displayname"]) - users, total = yield defer.ensureDeferred( + users, total = self.get_success( self.store.get_users_paginate(0, 10, name="BC", guests=False) ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index a1ba99ff14..d37736edf8 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -11,19 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest class ProfileStoreTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") - def test_displayname(self): + def test_displayname(self) -> None: self.get_success(self.store.create_profile(self.u_frank.localpart)) self.get_success( @@ -48,7 +51,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) ) - def test_avatar_url(self): + def test_avatar_url(self) -> None: self.get_success(self.store.create_profile(self.u_frank.localpart)) self.get_success( diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index a6be9a1bb1..cfc8098af6 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import List
+from unittest import mock
+
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
@@ -19,6 +22,22 @@ from synapse.storage.schema import SCHEMA_VERSION
 
 from tests.unittest import HomeserverTestCase
 
 
+def fake_listdir(filepath: str) -> List[str]:
+    """
+    A fake implementation of os.listdir which we can use to mock out the filesystem.
+
+    Args:
+        filepath: The directory to list files for.
+
+    Returns:
+        A list of files and folders in the directory.
+    """
+    if filepath.endswith("full_schemas"):
+        return [str(SCHEMA_VERSION)]
+
+    return ["99_add_unicorn_to_database.sql"]
+
+
 class WorkerSchemaTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
@@ -51,7 +70,7 @@ class WorkerSchemaTests(HomeserverTestCase):
 
         prepare_database(db_conn, db_pool.engine, self.hs.config)
 
-    def test_not_upgraded(self):
+    def test_not_upgraded_old_schema_version(self):
         """Test that workers don't start if the DB has an older schema version"""
         db_pool = self.hs.get_datastore().db_pool
         db_conn = LoggingDatabaseConnection(
@@ -67,3 +86,34 @@ class WorkerSchemaTests(HomeserverTestCase):
 
         with self.assertRaises(PrepareDatabaseException):
             prepare_database(db_conn, db_pool.engine, self.hs.config)
+
+    def test_not_upgraded_current_schema_version_with_outstanding_deltas(self):
+        """
+        Test that workers don't start if the DB is on the current schema version,
+        but there are still outstanding delta migrations to run.
+        """
+        db_pool = self.hs.get_datastore().db_pool
+        db_conn = LoggingDatabaseConnection(
+            db_pool._db_pool.connect(),
+            db_pool.engine,
+            "tests",
+        )
+
+        # Set the schema version of the database to the current version
+        cur = db_conn.cursor()
+        cur.execute("UPDATE schema_version SET version = ?", (SCHEMA_VERSION,))
+
+        db_conn.commit()
+
+        # Patch `os.listdir` here to make synapse think that there is a migration
+        # file ready to be run.
+        # Note that we can't patch this function for the whole method, else Synapse
+        # will try to find the file when building the database initially.
+        with mock.patch("os.listdir", mock.Mock(side_effect=fake_listdir)):
+            with self.assertRaises(PrepareDatabaseException):
+                # Synapse should think that there is an outstanding migration file due to
+                # patching 'os.listdir' in the context manager above.
+                #
+                # We expect Synapse to raise an exception to indicate the master process
+                # needs to apply this migration file.
+ prepare_database(db_conn, db_pool.engine, self.hs.config) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 2873e22ccf..fccab733c0 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -161,6 +161,54 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + def test__null_byte_in_display_name_properly_handled(self): + room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + + res = self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ) + # Check that we only got one result back + self.assertEqual(len(res), 1) + + # Check that alice's display name is "alice" + self.assertEqual(res[0]["display_name"], "alice") + + # Grab the event_id to use later + event_id = res[0]["event_id"] + + # Create a profile with the offending null byte in the display name + new_profile = {"displayname": "ali\u0000ce"} + + # Ensure that the change goes smoothly and does not fail due to the null byte + self.helper.change_membership( + room, + self.u_alice, + self.u_alice, + "join", + extra_data=new_profile, + tok=self.t_alice, + ) + + res2 = self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ) + # Check that we only have two results + self.assertEqual(len(res2), 2) + + # Filter out the previous event using the event_id we grabbed above + row = [row for row in res2 if row["event_id"] != event_id] + + # Check that alice's display name is now None + self.assertEqual(row[0]["display_name"], None) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 37cf7bb232..7f5b28aed8 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -23,6 +23,7 @@ from synapse.rest import admin from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.storage.background_updates import _BackgroundUpdateHandler from synapse.storage.roommember import ProfileInfo from synapse.util import Clock @@ -391,7 +392,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): with mock.patch.dict( self.store.db_pool.updates._background_update_handlers, - populate_user_directory_process_users=mocked_process_users, + populate_user_directory_process_users=_BackgroundUpdateHandler( + mocked_process_users, + ), ): self._purge_and_rebuild_user_dir() diff --git a/tests/test_federation.py b/tests/test_federation.py index 24fc77d7a7..3eef1c4c05 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -81,8 +81,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase): origin, event, context, - state=None, - backfilled=False, ): return context diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 94b19788d7..e0b08d67d4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -13,35 +13,30 @@ # limitations under the License. 
import logging from typing import Optional -from unittest.mock import Mock - -from twisted.internet import defer -from twisted.internet.defer import succeed from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent +from synapse.events import EventBase +from synapse.types import JsonDict from synapse.visibility import filter_events_for_server -import tests.unittest -from tests.utils import create_room, setup_test_homeserver +from tests import unittest +from tests.utils import create_room logger = logging.getLogger(__name__) TEST_ROOM_ID = "!TEST:ROOM" -class FilterEventsForServerTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) +class FilterEventsForServerTestCase(unittest.HomeserverTestCase): + def setUp(self) -> None: + super(FilterEventsForServerTestCase, self).setUp() self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.storage = self.hs.get_storage() - yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) + self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) - @defer.inlineCallbacks - def test_filtering(self): + def test_filtering(self) -> None: # # The events to be filtered consist of 10 membership events (it doesn't # really matter if they are joins or leaves, so let's make them joins). @@ -51,18 +46,20 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # # before we do that, we persist some other events to act as state. - yield self.inject_visibility("@admin:hs", "joined") + self.get_success(self._inject_visibility("@admin:hs", "joined")) for i in range(0, 10): - yield self.inject_room_member("@resident%i:hs" % i) + self.get_success(self._inject_room_member("@resident%i:hs" % i)) events_to_filter = [] for i in range(0, 10): user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") - evt = yield self.inject_room_member(user, extra_content={"a": "b"}) + evt = self.get_success( + self._inject_room_member(user, extra_content={"a": "b"}) + ) events_to_filter.append(evt) - filtered = yield defer.ensureDeferred( + filtered = self.get_success( filter_events_for_server(self.storage, "test_server", events_to_filter) ) @@ -75,34 +72,31 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") - @defer.inlineCallbacks - def test_erased_user(self): + def test_erased_user(self) -> None: # 4 message events, from erased and unerased users, with a membership # change in the middle of them. 
events_to_filter = [] - evt = yield self.inject_message("@unerased:local_hs") + evt = self.get_success(self._inject_message("@unerased:local_hs")) events_to_filter.append(evt) - evt = yield self.inject_message("@erased:local_hs") + evt = self.get_success(self._inject_message("@erased:local_hs")) events_to_filter.append(evt) - evt = yield self.inject_room_member("@joiner:remote_hs") + evt = self.get_success(self._inject_room_member("@joiner:remote_hs")) events_to_filter.append(evt) - evt = yield self.inject_message("@unerased:local_hs") + evt = self.get_success(self._inject_message("@unerased:local_hs")) events_to_filter.append(evt) - evt = yield self.inject_message("@erased:local_hs") + evt = self.get_success(self._inject_message("@erased:local_hs")) events_to_filter.append(evt) # the erasey user gets erased - yield defer.ensureDeferred( - self.hs.get_datastore().mark_user_erased("@erased:local_hs") - ) + self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs")) # ... and the filtering happens. - filtered = yield defer.ensureDeferred( + filtered = self.get_success( filter_events_for_server(self.storage, "test_server", events_to_filter) ) @@ -123,8 +117,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): for i in (1, 4): self.assertNotIn("body", filtered[i].content) - @defer.inlineCallbacks - def inject_visibility(self, user_id, visibility): + def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: content = {"history_visibility": visibility} builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -137,18 +130,18 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield defer.ensureDeferred( + event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - yield defer.ensureDeferred( - self.storage.persistence.persist_event(event, context) - ) + self.get_success(self.storage.persistence.persist_event(event, context)) return event - @defer.inlineCallbacks - def inject_room_member( - self, user_id, membership="join", extra_content: Optional[dict] = None - ): + def _inject_room_member( + self, + user_id: str, + membership: str = "join", + extra_content: Optional[JsonDict] = None, + ) -> EventBase: content = {"membership": membership} content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( @@ -162,17 +155,16 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield defer.ensureDeferred( + event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - yield defer.ensureDeferred( - self.storage.persistence.persist_event(event, context) - ) + self.get_success(self.storage.persistence.persist_event(event, context)) return event - @defer.inlineCallbacks - def inject_message(self, user_id, content=None): + def _inject_message( + self, user_id: str, content: Optional[JsonDict] = None + ) -> EventBase: if content is None: content = {"body": "testytest", "msgtype": "m.text"} builder = self.event_builder_factory.for_room_version( @@ -185,164 +177,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield defer.ensureDeferred( + event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - yield defer.ensureDeferred( - self.storage.persistence.persist_event(event, context) - ) + self.get_success(self.storage.persistence.persist_event(event, context)) return event - - 
@defer.inlineCallbacks - def test_large_room(self): - # see what happens when we have a large room with hundreds of thousands - # of membership events - - # As above, the events to be filtered consist of 10 membership events, - # where one of them is for a user on the server we are filtering for. - - import cProfile - import pstats - import time - - # we stub out the store, because building up all that state the normal - # way is very slow. - test_store = _TestStore() - - # our initial state is 100000 membership events and one - # history_visibility event. - room_state = [] - - history_visibility_evt = FrozenEvent( - { - "event_id": "$history_vis", - "type": "m.room.history_visibility", - "sender": "@resident_user_0:test.com", - "state_key": "", - "room_id": TEST_ROOM_ID, - "content": {"history_visibility": "joined"}, - } - ) - room_state.append(history_visibility_evt) - test_store.add_event(history_visibility_evt) - - for i in range(0, 100000): - user = "@resident_user_%i:test.com" % (i,) - evt = FrozenEvent( - { - "event_id": "$res_event_%i" % (i,), - "type": "m.room.member", - "state_key": user, - "sender": user, - "room_id": TEST_ROOM_ID, - "content": {"membership": "join", "extra": "zzz,"}, - } - ) - room_state.append(evt) - test_store.add_event(evt) - - events_to_filter = [] - for i in range(0, 10): - user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") - evt = FrozenEvent( - { - "event_id": "$evt%i" % (i,), - "type": "m.room.member", - "state_key": user, - "sender": user, - "room_id": TEST_ROOM_ID, - "content": {"membership": "join", "extra": "zzz"}, - } - ) - events_to_filter.append(evt) - room_state.append(evt) - - test_store.add_event(evt) - test_store.set_state_ids_for_event( - evt, {(e.type, e.state_key): e.event_id for e in room_state} - ) - - pr = cProfile.Profile() - pr.enable() - - logger.info("Starting filtering") - start = time.time() - - storage = Mock() - storage.main = test_store - storage.state = test_store - - filtered = yield defer.ensureDeferred( - filter_events_for_server(test_store, "test_server", events_to_filter) - ) - logger.info("Filtering took %f seconds", time.time() - start) - - pr.disable() - with open("filter_events_for_server.profile", "w+") as f: - ps = pstats.Stats(pr, stream=f).sort_stats("cumulative") - ps.print_stats() - - # the result should be 5 redacted events, and 5 unredacted events. 
- for i in range(0, 5): - self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) - self.assertNotIn("extra", filtered[i].content) - - for i in range(5, 10): - self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) - self.assertEqual(filtered[i].content["extra"], "zzz") - - test_large_room.skip = "Disabled by default because it's slow" - - -class _TestStore: - """Implements a few methods of the DataStore, so that we can test - filter_events_for_server - - """ - - def __init__(self): - # data for get_events: a map from event_id to event - self.events = {} - - # data for get_state_ids_for_events mock: a map from event_id to - # a map from (type_state_key) -> event_id for the state at that - # event - self.state_ids_for_events = {} - - def add_event(self, event): - self.events[event.event_id] = event - - def set_state_ids_for_event(self, event, state): - self.state_ids_for_events[event.event_id] = state - - def get_state_ids_for_events(self, events, types): - res = {} - include_memberships = False - for (type, state_key) in types: - if type == "m.room.history_visibility": - continue - if type != "m.room.member" or state_key is not None: - raise RuntimeError( - "Unimplemented: get_state_ids with type (%s, %s)" - % (type, state_key) - ) - include_memberships = True - - if include_memberships: - for event_id in events: - res[event_id] = self.state_ids_for_events[event_id] - - else: - k = ("m.room.history_visibility", "") - for event_id in events: - hve = self.state_ids_for_events[event_id][k] - res[event_id] = {k: hve} - - return succeed(res) - - def get_events(self, events): - return succeed({event_id: self.events[event_id] for event_id in events}) - - def are_users_erased(self, users): - return succeed({u: False for u in users}) diff --git a/tests/unittest.py b/tests/unittest.py index a9b60b7eeb..1431848367 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,20 @@ import inspect import logging import secrets import time -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union +from typing import ( + Any, + AnyStr, + Callable, + ClassVar, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from unittest.mock import Mock, patch from canonicaljson import json @@ -31,6 +44,7 @@ from twisted.python.threadpool import ThreadPool from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest from twisted.web.resource import Resource +from twisted.web.server import Request from synapse import events from synapse.api.constants import EventTypes, Membership @@ -45,6 +59,7 @@ from synapse.logging.context import ( current_context, set_current_context, ) +from synapse.rest import RegisterServletsFunc from synapse.server import HomeServer from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock @@ -81,16 +96,13 @@ def around(target): return _around -T = TypeVar("T") - - class TestCase(unittest.TestCase): """A subclass of twisted.trial's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the root logger's logging level while that test (case|method) runs.""" - def __init__(self, methodName, *args, **kwargs): - super().__init__(methodName, *args, **kwargs) + def __init__(self, methodName: str): + super().__init__(methodName) method = getattr(self, methodName) @@ -204,18 +216,18 @@ class HomeserverTestCase(TestCase): config dict. 
     Attributes:
-        servlets (list[function]): List of servlet registration function.
+        servlets: List of servlet registration functions.
         user_id (str): The user ID to assume if auth is hijacked.
-        hijack_auth (bool): Whether to hijack auth to return the user specified
+        hijack_auth: Whether to hijack auth to return the user specified
         in user_id.
     """
 
-    servlets = []
-    hijack_auth = True
-    needs_threadpool = False
+    hijack_auth: ClassVar[bool] = True
+    needs_threadpool: ClassVar[bool] = False
+    servlets: ClassVar[List[RegisterServletsFunc]] = []
 
-    def __init__(self, methodName, *args, **kwargs):
-        super().__init__(methodName, *args, **kwargs)
+    def __init__(self, methodName: str):
+        super().__init__(methodName)
 
         # see if we have any additional config for this test
         method = getattr(self, methodName)
@@ -287,9 +299,10 @@ class HomeserverTestCase(TestCase):
             None,
         )
 
-        self.hs.get_auth().get_user_by_req = get_user_by_req
-        self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
-        self.hs.get_auth().get_access_token_from_request = Mock(
+        # Type ignore: mypy doesn't like us assigning to methods.
+        self.hs.get_auth().get_user_by_req = get_user_by_req  # type: ignore[assignment]
+        self.hs.get_auth().get_user_by_access_token = get_user_by_access_token  # type: ignore[assignment]
+        self.hs.get_auth().get_access_token_from_request = Mock(  # type: ignore[assignment]
             return_value="1234"
         )
@@ -319,11 +332,12 @@ class HomeserverTestCase(TestCase):
 
     def wait_for_background_updates(self) -> None:
         """Block until all background database updates have completed."""
+        store = self.hs.get_datastore()
         while not self.get_success(
-            self.store.db_pool.updates.has_completed_background_updates()
+            store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+                store.db_pool.updates.do_next_background_update(False), by=0.1
             )
 
     def make_homeserver(self, reactor, clock):
@@ -403,14 +417,12 @@ class HomeserverTestCase(TestCase):
         path: Union[bytes, str],
         content: Union[bytes, str, JsonDict] = b"",
         access_token: Optional[str] = None,
-        request: Type[T] = SynapseRequest,
+        request: Type[Request] = SynapseRequest,
         shorthand: bool = True,
-        federation_auth_origin: str = None,
+        federation_auth_origin: Optional[bytes] = None,
         content_is_form: bool = False,
         await_result: bool = True,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """
@@ -425,7 +437,7 @@ class HomeserverTestCase(TestCase):
             a dict.
         shorthand: Whether to try and be helpful and prefix the given URL
             with the usual REST API path, if it doesn't contain it.
-        federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+        federation_auth_origin: if set to not-None, we will add a fake
             Authorization header pretending to be the given server name.
         content_is_form: Whether the content is URL encoded form data. Adds the
            'Content-Type': 'application/x-www-form-urlencoded' header.
@@ -484,8 +496,7 @@ class HomeserverTestCase(TestCase):
 
         async def run_bg_updates():
             with LoggingContext("run_bg_updates"):
-                while not await stor.db_pool.updates.has_completed_background_updates():
-                    await stor.db_pool.updates.do_next_background_update(1)
+                self.get_success(stor.db_pool.updates.run_background_updates(False))
 
         hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
         stor = hs.get_datastore()
@@ -584,7 +595,7 @@ class HomeserverTestCase(TestCase):
             nonce_str += b"\x00notadmin"
 
         want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
-        want_mac = want_mac.hexdigest()
+        want_mac_digest = want_mac.hexdigest()
 
         body = json.dumps(
             {
@@ -593,7 +604,7 @@ class HomeserverTestCase(TestCase):
                 "displayname": displayname,
                 "password": password,
                 "admin": admin,
-                "mac": want_mac,
+                "mac": want_mac_digest,
                 "inhibit_login": True,
             }
         )
@@ -639,9 +650,7 @@ class HomeserverTestCase(TestCase):
         username,
         password,
         device_id=None,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         """
         Log in a user, and get an access token. Requires the Login API be
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 6578f3411e..291644eb7d 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+from typing import List
 from unittest.mock import Mock
 
 from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
@@ -261,6 +262,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
         self.assertEquals(cache["key4"], [4])
         self.assertEquals(cache["key5"], [5, 6])
 
+    def test_zero_size_drop_from_cache(self) -> None:
+        """Test that `drop_from_cache` works correctly with 0-sized entries."""
+        cache: LruCache[str, List[int]] = LruCache(5, size_callback=lambda x: 0)
+        cache["key1"] = []
+
+        self.assertEqual(len(cache), 0)
+        cache.cache["key1"].drop_from_cache()
+        self.assertIsNone(
+            cache.pop("key1"), "Cache entry should have been evicted but wasn't"
+        )
+
 
 class TimeEvictionTestCase(unittest.HomeserverTestCase):
     """Test that time based eviction works correctly."""
diff --git a/tests/utils.py b/tests/utils.py
index cf8ba5c5db..983859120f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -119,7 +119,6 @@ def default_config(name, parse=False):
         "enable_registration": True,
         "enable_registration_captcha": False,
         "macaroon_secret_key": "not even a little secret",
-        "trusted_third_party_id_servers": [],
         "password_providers": [],
         "worker_replication_url": "",
         "worker_app": None,
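
The dominant pattern in the test hunks above is the migration from Twisted's @defer.inlineCallbacks/yield style to HomeserverTestCase with self.get_success(...), which drives a Deferred to completion on the test reactor and hands back its result, so tests read synchronously. A minimal sketch of the idea, assuming only Twisted is installed: get_success below is a simplified stand-in for the real helper (which also pumps Synapse's MemoryReactor), and get_appservice_state is a hypothetical coroutine, not the store method itself.

    from typing import List, TypeVar

    from twisted.internet import defer

    T = TypeVar("T")


    def get_success(d: "defer.Deferred[T]") -> T:
        # Simplified stand-in: assumes `d` has already fired, which holds here
        # because the coroutine below never waits on a pending Deferred.
        results: List[T] = []
        d.addCallback(lambda value: (results.append(value), value)[1])
        if not results:
            raise AssertionError("Deferred has not fired yet")
        return results[0]


    async def get_appservice_state() -> str:
        # Hypothetical coroutine standing in for store.get_appservice_state(...).
        return "up"


    # Old style:  state = yield defer.ensureDeferred(store.get_appservice_state(service))
    # New style:
    state = get_success(defer.ensureDeferred(get_appservice_state()))
    assert state == "up"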
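Several hunks in tests/storage/test_appservice.py also switch raw SQL parameters from ApplicationServiceState.UP to ApplicationServiceState.UP.value, because the state column stores the enum's underlying string rather than the enum member. A self-contained sqlite3 sketch of why the .value is needed; the ServiceState enum and table layout here are illustrative stand-ins, not Synapse's schema:

    import sqlite3
    from enum import Enum


    class ServiceState(Enum):
        # Illustrative stand-in for synapse's ApplicationServiceState.
        UP = "up"
        DOWN = "down"


    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE application_services_state (as_id TEXT, state TEXT)")

    # Binding the Enum member itself is rejected by the driver (unsupported
    # parameter type); binding its `.value` stores the plain string, which is
    # what SELECTs then compare against.
    conn.execute(
        "INSERT INTO application_services_state VALUES (?, ?)",
        ("as1", ServiceState.DOWN.value),
    )

    rows = conn.execute(
        "SELECT as_id FROM application_services_state WHERE state = ?",
        (ServiceState.DOWN.value,),
    ).fetchall()
    assert rows == [("as1",)]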
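Finally, the new BackgroundUpdateControllerTestCase mocks out an async context manager by hand, since a plain Mock does not implement __aenter__/__aexit__. A stripped-down sketch of the same pattern using plain asyncio instead of the homeserver plumbing; make_awaitable below is a minimal analogue (an assumption, not the exact helper) of the one imported from tests.test_utils:

    import asyncio
    from unittest.mock import Mock


    def make_awaitable(result):
        # Minimal analogue of tests.test_utils.make_awaitable: an
        # already-resolved awaitable carrying `result`.
        future: asyncio.Future = asyncio.Future()
        future.set_result(result)
        return future


    async def main() -> None:
        # Assigning magic methods on a Mock installs them on the mock's
        # per-instance class, so `async with` finds them via type lookup.
        ctx = Mock(spec=["__aenter__", "__aexit__"])
        ctx.__aenter__ = Mock(return_value=make_awaitable(100))
        ctx.__aexit__ = Mock(return_value=make_awaitable(None))

        async with ctx as batch_size:
            assert batch_size == 100

        ctx.__aenter__.assert_called_once()
        ctx.__aexit__.assert_called_once()


    asyncio.run(main())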