diff options
739 files changed, 26192 insertions, 15547 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml index 98c217dd1d..5bd2ab2b76 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -4,18 +4,16 @@ jobs: machine: true steps: - checkout - - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:${CIRCLE_TAG} -t matrixdotorg/synapse:${CIRCLE_TAG}-py3 . + - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:${CIRCLE_TAG} . - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD - run: docker push matrixdotorg/synapse:${CIRCLE_TAG} - - run: docker push matrixdotorg/synapse:${CIRCLE_TAG}-py3 dockerhubuploadlatest: machine: true steps: - checkout - - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:latest -t matrixdotorg/synapse:latest-py3 . + - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:latest . - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD - run: docker push matrixdotorg/synapse:latest - - run: docker push matrixdotorg/synapse:latest-py3 workflows: version: 2 diff --git a/.github/ISSUE_TEMPLATE/BUG_REPORT.md b/.github/ISSUE_TEMPLATE/BUG_REPORT.md index 75c9b2c9fe..978b699886 100644 --- a/.github/ISSUE_TEMPLATE/BUG_REPORT.md +++ b/.github/ISSUE_TEMPLATE/BUG_REPORT.md @@ -4,12 +4,12 @@ about: Create a report to help us improve --- +<!-- + **THIS IS NOT A SUPPORT CHANNEL!** **IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**, please ask in **#synapse:matrix.org** (using a matrix.org account if necessary) -<!-- - If you want to report a security issue, please see https://matrix.org/security-disclosure-policy/ This is a bug report template. By following the instructions below and diff --git a/CHANGES.md b/CHANGES.md index 9a30a2e901..dd4d110981 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,512 @@ +Synapse 1.20.0rc3 (2020-09-11) +============================== + +Bugfixes +-------- + +- Fix a bug introduced in v1.20.0rc1 where the wrong exception was raised when invalid JSON data is encountered. ([\#8291](https://github.com/matrix-org/synapse/issues/8291)) + + +Synapse 1.20.0rc2 (2020-09-09) +============================== + +Bugfixes +-------- + +- Fix a bug introduced in v1.20.0rc1 causing some features related to notifications to misbehave following the implementation of unread counts. ([\#8280](https://github.com/matrix-org/synapse/issues/8280)) + + +Synapse 1.20.0rc1 (2020-09-08) +============================== + +Removal warning +--------------- + +Some older clients used a [disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken) (`:`) in the `client_secret` parameter of various endpoints. The incorrect behaviour was allowed for backwards compatibility, but is now being removed from Synapse as most users have updated their client. Further context can be found at [\#6766](https://github.com/matrix-org/synapse/issues/6766). + +Features +-------- + +- Add an endpoint to query your shared rooms with another user as an implementation of [MSC2666](https://github.com/matrix-org/matrix-doc/pull/2666). ([\#7785](https://github.com/matrix-org/synapse/issues/7785)) +- Iteratively encode JSON to avoid blocking the reactor. ([\#8013](https://github.com/matrix-org/synapse/issues/8013), [\#8116](https://github.com/matrix-org/synapse/issues/8116)) +- Add support for shadow-banning users (ignoring any message send requests). ([\#8034](https://github.com/matrix-org/synapse/issues/8034), [\#8092](https://github.com/matrix-org/synapse/issues/8092), [\#8095](https://github.com/matrix-org/synapse/issues/8095), [\#8142](https://github.com/matrix-org/synapse/issues/8142), [\#8152](https://github.com/matrix-org/synapse/issues/8152), [\#8157](https://github.com/matrix-org/synapse/issues/8157), [\#8158](https://github.com/matrix-org/synapse/issues/8158), [\#8176](https://github.com/matrix-org/synapse/issues/8176)) +- Use the default template file when its equivalent is not found in a custom template directory. ([\#8037](https://github.com/matrix-org/synapse/issues/8037), [\#8107](https://github.com/matrix-org/synapse/issues/8107), [\#8252](https://github.com/matrix-org/synapse/issues/8252)) +- Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). ([\#8059](https://github.com/matrix-org/synapse/issues/8059), [\#8254](https://github.com/matrix-org/synapse/issues/8254), [\#8270](https://github.com/matrix-org/synapse/issues/8270), [\#8274](https://github.com/matrix-org/synapse/issues/8274)) +- Optimise `/federation/v1/user/devices/` API by only returning devices with encryption keys. ([\#8198](https://github.com/matrix-org/synapse/issues/8198)) + + +Bugfixes +-------- + +- Fix a memory leak by limiting the length of time that messages will be queued for a remote server that has been unreachable. ([\#7864](https://github.com/matrix-org/synapse/issues/7864)) +- Fix `Re-starting finished log context PUT-nnnn` warning when event persistence failed. ([\#8081](https://github.com/matrix-org/synapse/issues/8081)) +- Synapse now correctly enforces the valid characters in the `client_secret` parameter used in various endpoints. ([\#8101](https://github.com/matrix-org/synapse/issues/8101)) +- Fix a bug introduced in v1.7.2 impacting message retention policies that would allow federated homeservers to dictate a retention period that's lower than the configured minimum allowed duration in the configuration file. ([\#8104](https://github.com/matrix-org/synapse/issues/8104)) +- Fix a long-standing bug where invalid JSON would be accepted by Synapse. ([\#8106](https://github.com/matrix-org/synapse/issues/8106)) +- Fix a bug introduced in Synapse v1.12.0 which could cause `/sync` requests to fail with a 404 if you had a very old outstanding room invite. ([\#8110](https://github.com/matrix-org/synapse/issues/8110)) +- Return a proper error code when the rooms of an invalid group are requested. ([\#8129](https://github.com/matrix-org/synapse/issues/8129)) +- Fix a bug which could cause a leaked postgres connection if synapse was set to daemonize. ([\#8131](https://github.com/matrix-org/synapse/issues/8131)) +- Clarify the error code if a user tries to register with a numeric ID. This bug was introduced in v1.15.0. ([\#8135](https://github.com/matrix-org/synapse/issues/8135)) +- Fix a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0. ([\#8139](https://github.com/matrix-org/synapse/issues/8139)) +- Fix logging in via OpenID Connect with a provider that uses integer user IDs. ([\#8190](https://github.com/matrix-org/synapse/issues/8190)) +- Fix a longstanding bug where user directory updates could break when unexpected profile data was included in events. ([\#8223](https://github.com/matrix-org/synapse/issues/8223)) +- Fix a longstanding bug where stats updates could break when unexpected profile data was included in events. ([\#8226](https://github.com/matrix-org/synapse/issues/8226)) +- Fix slow start times for large servers by removing a table scan of the `users` table from startup code. ([\#8271](https://github.com/matrix-org/synapse/issues/8271)) + + +Updates to the Docker image +--------------------------- + +- Fix builds of the Docker image on non-x86 platforms. ([\#8144](https://github.com/matrix-org/synapse/issues/8144)) +- Added curl for healthcheck support and readme updates for the change. Contributed by @maquis196. ([\#8147](https://github.com/matrix-org/synapse/issues/8147)) + + +Improved Documentation +---------------------- + +- Link to matrix-synapse-rest-password-provider in the password provider documentation. ([\#8111](https://github.com/matrix-org/synapse/issues/8111)) +- Updated documentation to note that Synapse does not follow `HTTP 308` redirects due to an upstream library not supporting them. Contributed by Ryan Cole. ([\#8120](https://github.com/matrix-org/synapse/issues/8120)) +- Explain better what GDPR-erased means when deactivating a user. ([\#8189](https://github.com/matrix-org/synapse/issues/8189)) + + +Internal Changes +---------------- + +- Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7377](https://github.com/matrix-org/synapse/issues/7377), [\#8163](https://github.com/matrix-org/synapse/issues/8163)) +- Reduce run times of some unit tests by advancing the reactor a fewer number of times. ([\#7757](https://github.com/matrix-org/synapse/issues/7757)) +- Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on. ([\#7991](https://github.com/matrix-org/synapse/issues/7991)) +- Convert various parts of the codebase to async/await. ([\#8071](https://github.com/matrix-org/synapse/issues/8071), [\#8072](https://github.com/matrix-org/synapse/issues/8072), [\#8074](https://github.com/matrix-org/synapse/issues/8074), [\#8075](https://github.com/matrix-org/synapse/issues/8075), [\#8076](https://github.com/matrix-org/synapse/issues/8076), [\#8087](https://github.com/matrix-org/synapse/issues/8087), [\#8100](https://github.com/matrix-org/synapse/issues/8100), [\#8119](https://github.com/matrix-org/synapse/issues/8119), [\#8121](https://github.com/matrix-org/synapse/issues/8121), [\#8133](https://github.com/matrix-org/synapse/issues/8133), [\#8156](https://github.com/matrix-org/synapse/issues/8156), [\#8162](https://github.com/matrix-org/synapse/issues/8162), [\#8166](https://github.com/matrix-org/synapse/issues/8166), [\#8168](https://github.com/matrix-org/synapse/issues/8168), [\#8173](https://github.com/matrix-org/synapse/issues/8173), [\#8191](https://github.com/matrix-org/synapse/issues/8191), [\#8192](https://github.com/matrix-org/synapse/issues/8192), [\#8193](https://github.com/matrix-org/synapse/issues/8193), [\#8194](https://github.com/matrix-org/synapse/issues/8194), [\#8195](https://github.com/matrix-org/synapse/issues/8195), [\#8197](https://github.com/matrix-org/synapse/issues/8197), [\#8199](https://github.com/matrix-org/synapse/issues/8199), [\#8200](https://github.com/matrix-org/synapse/issues/8200), [\#8201](https://github.com/matrix-org/synapse/issues/8201), [\#8202](https://github.com/matrix-org/synapse/issues/8202), [\#8207](https://github.com/matrix-org/synapse/issues/8207), [\#8213](https://github.com/matrix-org/synapse/issues/8213), [\#8214](https://github.com/matrix-org/synapse/issues/8214)) +- Remove some unused database functions. ([\#8085](https://github.com/matrix-org/synapse/issues/8085)) +- Add type hints to various parts of the codebase. ([\#8090](https://github.com/matrix-org/synapse/issues/8090), [\#8127](https://github.com/matrix-org/synapse/issues/8127), [\#8187](https://github.com/matrix-org/synapse/issues/8187), [\#8241](https://github.com/matrix-org/synapse/issues/8241), [\#8140](https://github.com/matrix-org/synapse/issues/8140), [\#8183](https://github.com/matrix-org/synapse/issues/8183), [\#8232](https://github.com/matrix-org/synapse/issues/8232), [\#8235](https://github.com/matrix-org/synapse/issues/8235), [\#8237](https://github.com/matrix-org/synapse/issues/8237), [\#8244](https://github.com/matrix-org/synapse/issues/8244)) +- Return the previous stream token if a non-member event is a duplicate. ([\#8093](https://github.com/matrix-org/synapse/issues/8093), [\#8112](https://github.com/matrix-org/synapse/issues/8112)) +- Separate `get_current_token` into two since there are two different use cases for it. ([\#8113](https://github.com/matrix-org/synapse/issues/8113)) +- Remove `ChainedIdGenerator`. ([\#8123](https://github.com/matrix-org/synapse/issues/8123)) +- Reduce the amount of whitespace in JSON stored and sent in responses. ([\#8124](https://github.com/matrix-org/synapse/issues/8124)) +- Update the test federation client to handle streaming responses. ([\#8130](https://github.com/matrix-org/synapse/issues/8130)) +- Micro-optimisations to `get_auth_chain_ids`. ([\#8132](https://github.com/matrix-org/synapse/issues/8132)) +- Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface. ([\#8161](https://github.com/matrix-org/synapse/issues/8161)) +- Add functions to `MultiWriterIdGen` used by events stream. ([\#8164](https://github.com/matrix-org/synapse/issues/8164), [\#8179](https://github.com/matrix-org/synapse/issues/8179)) +- Fix tests that were broken due to the merge of 1.19.1. ([\#8167](https://github.com/matrix-org/synapse/issues/8167)) +- Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`. ([\#8171](https://github.com/matrix-org/synapse/issues/8171)) +- Remove unused `is_guest` parameter from, and add safeguard to, `MessageHandler.get_room_data`. ([\#8174](https://github.com/matrix-org/synapse/issues/8174), [\#8181](https://github.com/matrix-org/synapse/issues/8181)) +- Standardize the mypy configuration. ([\#8175](https://github.com/matrix-org/synapse/issues/8175)) +- Refactor some of `LoginRestServlet`'s helper methods, and move them to `AuthHandler` for easier reuse. ([\#8182](https://github.com/matrix-org/synapse/issues/8182)) +- Fix `wait_for_stream_position` to allow multiple waiters on same stream ID. ([\#8196](https://github.com/matrix-org/synapse/issues/8196)) +- Make `MultiWriterIDGenerator` work for streams that use negative values. ([\#8203](https://github.com/matrix-org/synapse/issues/8203)) +- Refactor queries for device keys and cross-signatures. ([\#8204](https://github.com/matrix-org/synapse/issues/8204), [\#8205](https://github.com/matrix-org/synapse/issues/8205), [\#8222](https://github.com/matrix-org/synapse/issues/8222), [\#8224](https://github.com/matrix-org/synapse/issues/8224), [\#8225](https://github.com/matrix-org/synapse/issues/8225), [\#8231](https://github.com/matrix-org/synapse/issues/8231), [\#8233](https://github.com/matrix-org/synapse/issues/8233), [\#8234](https://github.com/matrix-org/synapse/issues/8234)) +- Fix type hints for functions decorated with `@cached`. ([\#8240](https://github.com/matrix-org/synapse/issues/8240)) +- Remove obsolete `order` field from federation send queues. ([\#8245](https://github.com/matrix-org/synapse/issues/8245)) +- Stop sub-classing from object. ([\#8249](https://github.com/matrix-org/synapse/issues/8249)) +- Add more logging to debug slow startup. ([\#8264](https://github.com/matrix-org/synapse/issues/8264)) +- Do not attempt to upgrade database schema on worker processes. ([\#8266](https://github.com/matrix-org/synapse/issues/8266), [\#8276](https://github.com/matrix-org/synapse/issues/8276)) + + +Synapse 1.19.1 (2020-08-27) +=========================== + +No significant changes. + + +Synapse 1.19.1rc1 (2020-08-25) +============================== + +Bugfixes +-------- + +- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139)) +- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153)) + + +Synapse 1.19.0 (2020-08-17) +=========================== + +No significant changes since 1.19.0rc1. + +Removal warning +--------------- + +As outlined in the [previous release](https://github.com/matrix-org/synapse/releases/tag/v1.18.0), we are no longer publishing Docker images with the `-py3` tag suffix. On top of that, we have also removed the `latest-py3` tag. Please see [the announcement in the upgrade notes for 1.18.0](https://github.com/matrix-org/synapse/blob/develop/UPGRADE.rst#upgrading-to-v1180). + + +Synapse 1.19.0rc1 (2020-08-13) +============================== + +Features +-------- + +- Add option to allow server admins to join rooms which fail complexity checks. Contributed by @lugino-emeritus. ([\#7902](https://github.com/matrix-org/synapse/issues/7902)) +- Add an option to purge room or not with delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel. ([\#7964](https://github.com/matrix-org/synapse/issues/7964)) +- Add rate limiting to users joining rooms. ([\#8008](https://github.com/matrix-org/synapse/issues/8008)) +- Add a `/health` endpoint to every configured HTTP listener that can be used as a health check endpoint by load balancers. ([\#8048](https://github.com/matrix-org/synapse/issues/8048)) +- Allow login to be blocked based on the values of SAML attributes. ([\#8052](https://github.com/matrix-org/synapse/issues/8052)) +- Allow guest access to the `GET /_matrix/client/r0/rooms/{room_id}/members` endpoint, according to MSC2689. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7314](https://github.com/matrix-org/synapse/issues/7314)) + + +Bugfixes +-------- + +- Fix a bug introduced in Synapse v1.7.2 which caused inaccurate membership counts in the room directory. ([\#7977](https://github.com/matrix-org/synapse/issues/7977)) +- Fix a long standing bug: 'Duplicate key value violates unique constraint "event_relations_id"' when message retention is configured. ([\#7978](https://github.com/matrix-org/synapse/issues/7978)) +- Fix "no create event in auth events" when trying to reject invitation after inviter leaves. Bug introduced in Synapse v1.10.0. ([\#7980](https://github.com/matrix-org/synapse/issues/7980)) +- Fix various comments and minor discrepencies in server notices code. ([\#7996](https://github.com/matrix-org/synapse/issues/7996)) +- Fix a long standing bug where HTTP HEAD requests resulted in a 400 error. ([\#7999](https://github.com/matrix-org/synapse/issues/7999)) +- Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger. ([\#8011](https://github.com/matrix-org/synapse/issues/8011), [\#8012](https://github.com/matrix-org/synapse/issues/8012)) + + +Updates to the Docker image +--------------------------- + +- We no longer publish Docker images with the `-py3` tag suffix, as [announced in the upgrade notes](https://github.com/matrix-org/synapse/blob/develop/UPGRADE.rst#upgrading-to-v1180). ([\#8056](https://github.com/matrix-org/synapse/issues/8056)) + + +Improved Documentation +---------------------- + +- Document how to set up a client .well-known file and fix several pieces of outdated documentation. ([\#7899](https://github.com/matrix-org/synapse/issues/7899)) +- Improve workers docs. ([\#7990](https://github.com/matrix-org/synapse/issues/7990), [\#8000](https://github.com/matrix-org/synapse/issues/8000)) +- Fix typo in `docs/workers.md`. ([\#7992](https://github.com/matrix-org/synapse/issues/7992)) +- Add documentation for how to undo a room shutdown. ([\#7998](https://github.com/matrix-org/synapse/issues/7998), [\#8010](https://github.com/matrix-org/synapse/issues/8010)) + + +Internal Changes +---------------- + +- Reduce the amount of whitespace in JSON stored and sent in responses. Contributed by David Vo. ([\#7372](https://github.com/matrix-org/synapse/issues/7372)) +- Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0. ([\#7936](https://github.com/matrix-org/synapse/issues/7936), [\#7979](https://github.com/matrix-org/synapse/issues/7979)) +- Convert various parts of the codebase to async/await. ([\#7947](https://github.com/matrix-org/synapse/issues/7947), [\#7948](https://github.com/matrix-org/synapse/issues/7948), [\#7949](https://github.com/matrix-org/synapse/issues/7949), [\#7951](https://github.com/matrix-org/synapse/issues/7951), [\#7963](https://github.com/matrix-org/synapse/issues/7963), [\#7973](https://github.com/matrix-org/synapse/issues/7973), [\#7975](https://github.com/matrix-org/synapse/issues/7975), [\#7976](https://github.com/matrix-org/synapse/issues/7976), [\#7981](https://github.com/matrix-org/synapse/issues/7981), [\#7987](https://github.com/matrix-org/synapse/issues/7987), [\#7989](https://github.com/matrix-org/synapse/issues/7989), [\#8003](https://github.com/matrix-org/synapse/issues/8003), [\#8014](https://github.com/matrix-org/synapse/issues/8014), [\#8016](https://github.com/matrix-org/synapse/issues/8016), [\#8027](https://github.com/matrix-org/synapse/issues/8027), [\#8031](https://github.com/matrix-org/synapse/issues/8031), [\#8032](https://github.com/matrix-org/synapse/issues/8032), [\#8035](https://github.com/matrix-org/synapse/issues/8035), [\#8042](https://github.com/matrix-org/synapse/issues/8042), [\#8044](https://github.com/matrix-org/synapse/issues/8044), [\#8045](https://github.com/matrix-org/synapse/issues/8045), [\#8061](https://github.com/matrix-org/synapse/issues/8061), [\#8062](https://github.com/matrix-org/synapse/issues/8062), [\#8063](https://github.com/matrix-org/synapse/issues/8063), [\#8066](https://github.com/matrix-org/synapse/issues/8066), [\#8069](https://github.com/matrix-org/synapse/issues/8069), [\#8070](https://github.com/matrix-org/synapse/issues/8070)) +- Move some database-related log lines from the default logger to the database/transaction loggers. ([\#7952](https://github.com/matrix-org/synapse/issues/7952)) +- Add a script to detect source code files using non-unix line terminators. ([\#7965](https://github.com/matrix-org/synapse/issues/7965), [\#7970](https://github.com/matrix-org/synapse/issues/7970)) +- Log the SAML session ID during creation. ([\#7971](https://github.com/matrix-org/synapse/issues/7971)) +- Implement new experimental push rules for some users. ([\#7997](https://github.com/matrix-org/synapse/issues/7997)) +- Remove redundant and unreliable signature check for v1 Identity Service lookup responses. ([\#8001](https://github.com/matrix-org/synapse/issues/8001)) +- Improve the performance of the register endpoint. ([\#8009](https://github.com/matrix-org/synapse/issues/8009)) +- Reduce less useful output in the newsfragment CI step. Add a link to the changelog section of the contributing guide on error. ([\#8024](https://github.com/matrix-org/synapse/issues/8024)) +- Rename storage layer objects to be more sensible. ([\#8033](https://github.com/matrix-org/synapse/issues/8033)) +- Change the default log config to reduce disk I/O and storage for new servers. ([\#8040](https://github.com/matrix-org/synapse/issues/8040)) +- Add an assertion on `prev_events` in `create_new_client_event`. ([\#8041](https://github.com/matrix-org/synapse/issues/8041)) +- Add a comment to `ServerContextFactory` about the use of `SSLv23_METHOD`. ([\#8043](https://github.com/matrix-org/synapse/issues/8043)) +- Log `OPTIONS` requests at `DEBUG` rather than `INFO` level to reduce amount logged at `INFO`. ([\#8049](https://github.com/matrix-org/synapse/issues/8049)) +- Reduce amount of outbound request logging at `INFO` level. ([\#8050](https://github.com/matrix-org/synapse/issues/8050)) +- It is no longer necessary to explicitly define `filters` in the logging configuration. (Continuing to do so is redundant but harmless.) ([\#8051](https://github.com/matrix-org/synapse/issues/8051)) +- Add and improve type hints. ([\#8058](https://github.com/matrix-org/synapse/issues/8058), [\#8064](https://github.com/matrix-org/synapse/issues/8064), [\#8060](https://github.com/matrix-org/synapse/issues/8060), [\#8067](https://github.com/matrix-org/synapse/issues/8067)) + + +Synapse 1.18.0 (2020-07-30) +=========================== + +Deprecation Warnings +-------------------- + +### Docker Tags with `-py3` Suffix + +From 10th August 2020, we will no longer publish Docker images with the `-py3` tag suffix. The images tagged with the `-py3` suffix have been identical to the non-suffixed tags since release 0.99.0, and the suffix is obsolete. + +On 10th August, we will remove the `latest-py3` tag. Existing per-release tags (such as `v1.18.0-py3`) will not be removed, but no new `-py3` tags will be added. + +Scripts relying on the `-py3` suffix will need to be updated. + + +### TCP-based Replication + +When setting up worker processes, we now recommend the use of a Redis server for replication. The old direct TCP connection method is deprecated and will be removed in a future release. See [docs/workers.md](https://github.com/matrix-org/synapse/blob/release-v1.18.0/docs/workers.md) for more details. + + +Improved Documentation +---------------------- + +- Update worker docs with latest enhancements. ([\#7969](https://github.com/matrix-org/synapse/issues/7969)) + + +Synapse 1.18.0rc2 (2020-07-28) +============================== + +Bugfixes +-------- + +- Fix an `AssertionError` exception introduced in v1.18.0rc1. ([\#7876](https://github.com/matrix-org/synapse/issues/7876)) +- Fix experimental support for moving typing off master when worker is restarted, which is broken in v1.18.0rc1. ([\#7967](https://github.com/matrix-org/synapse/issues/7967)) + + +Internal Changes +---------------- + +- Further optimise queueing of inbound replication commands. ([\#7876](https://github.com/matrix-org/synapse/issues/7876)) + + +Synapse 1.18.0rc1 (2020-07-27) +============================== + +Features +-------- + +- Include room states on invite events that are sent to application services. Contributed by @Sorunome. ([\#6455](https://github.com/matrix-org/synapse/issues/6455)) +- Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel. ([\#7613](https://github.com/matrix-org/synapse/issues/7613), [\#7953](https://github.com/matrix-org/synapse/issues/7953)) +- Add experimental support for running multiple federation sender processes. ([\#7798](https://github.com/matrix-org/synapse/issues/7798)) +- Add the option to validate the `iss` and `aud` claims for JWT logins. ([\#7827](https://github.com/matrix-org/synapse/issues/7827)) +- Add support for handling registration requests across multiple client reader workers. ([\#7830](https://github.com/matrix-org/synapse/issues/7830)) +- Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7842](https://github.com/matrix-org/synapse/issues/7842)) +- Allow email subjects to be customised through Synapse's configuration. ([\#7846](https://github.com/matrix-org/synapse/issues/7846)) +- Add the ability to re-activate an account from the admin API. ([\#7847](https://github.com/matrix-org/synapse/issues/7847), [\#7908](https://github.com/matrix-org/synapse/issues/7908)) +- Add experimental support for running multiple pusher workers. ([\#7855](https://github.com/matrix-org/synapse/issues/7855)) +- Add experimental support for moving typing off master. ([\#7869](https://github.com/matrix-org/synapse/issues/7869), [\#7959](https://github.com/matrix-org/synapse/issues/7959)) +- Report CPU metrics to prometheus for time spent processing replication commands. ([\#7879](https://github.com/matrix-org/synapse/issues/7879)) +- Support oEmbed for media previews. ([\#7920](https://github.com/matrix-org/synapse/issues/7920)) +- Abort federation requests where the client disconnects before the ratelimiter expires. ([\#7930](https://github.com/matrix-org/synapse/issues/7930)) +- Cache responses to `/_matrix/federation/v1/state_ids` to reduce duplicated work. ([\#7931](https://github.com/matrix-org/synapse/issues/7931)) + + +Bugfixes +-------- + +- Fix detection of out of sync remote device lists when receiving events from remote users. ([\#7815](https://github.com/matrix-org/synapse/issues/7815)) +- Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain. ([\#7817](https://github.com/matrix-org/synapse/issues/7817)) +- Fix a bug causing Synapse to misinterpret the value `off` for `encryption_enabled_by_default_for_room_type` in its configuration file(s) if that value isn't surrounded by quotes. This bug was introduced in v1.16.0. ([\#7822](https://github.com/matrix-org/synapse/issues/7822)) +- Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails. ([\#7829](https://github.com/matrix-org/synapse/issues/7829)) +- Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`. ([\#7844](https://github.com/matrix-org/synapse/issues/7844)) +- Fix "AttributeError: 'str' object has no attribute 'get'" error message when applying per-room message retention policies. The bug was introduced in Synapse 1.7.0. ([\#7850](https://github.com/matrix-org/synapse/issues/7850)) +- Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation. ([\#7854](https://github.com/matrix-org/synapse/issues/7854)) +- Fix a bug which allowed empty rooms to be rejoined over federation. ([\#7859](https://github.com/matrix-org/synapse/issues/7859)) +- Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers. ([\#7866](https://github.com/matrix-org/synapse/issues/7866)) +- Fix a long standing bug where the tracing of async functions with opentracing was broken. ([\#7872](https://github.com/matrix-org/synapse/issues/7872), [\#7961](https://github.com/matrix-org/synapse/issues/7961)) +- Fix "TypeError in `synapse.notifier`" exceptions. ([\#7880](https://github.com/matrix-org/synapse/issues/7880)) +- Fix deprecation warning due to invalid escape sequences. ([\#7895](https://github.com/matrix-org/synapse/issues/7895)) + + +Updates to the Docker image +--------------------------- + +- Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196. ([\#7839](https://github.com/matrix-org/synapse/issues/7839)) + + +Improved Documentation +---------------------- + +- Provide instructions on using `register_new_matrix_user` via docker. ([\#7885](https://github.com/matrix-org/synapse/issues/7885)) +- Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation. ([\#7889](https://github.com/matrix-org/synapse/issues/7889)) +- Reorder database paragraphs to promote postgres over sqlite. ([\#7933](https://github.com/matrix-org/synapse/issues/7933)) +- Update the dates of ACME v1's end of life in [`ACME.md`](https://github.com/matrix-org/synapse/blob/master/docs/ACME.md). ([\#7934](https://github.com/matrix-org/synapse/issues/7934)) + + +Deprecations and Removals +------------------------- + +- Remove unused `synapse_replication_tcp_resource_invalidate_cache` prometheus metric. ([\#7878](https://github.com/matrix-org/synapse/issues/7878)) +- Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim. ([\#7888](https://github.com/matrix-org/synapse/issues/7888)) + + +Internal Changes +---------------- + +- Switch parts of the codebase from `simplejson` to the standard library `json`. ([\#7802](https://github.com/matrix-org/synapse/issues/7802)) +- Add type hints to the http server code and remove an unused parameter. ([\#7813](https://github.com/matrix-org/synapse/issues/7813)) +- Add type hints to synapse.api.errors module. ([\#7820](https://github.com/matrix-org/synapse/issues/7820)) +- Ensure that calls to `json.dumps` are compatible with the standard library json. ([\#7836](https://github.com/matrix-org/synapse/issues/7836)) +- Remove redundant `retry_on_integrity_error` wrapper for event persistence code. ([\#7848](https://github.com/matrix-org/synapse/issues/7848)) +- Consistently use `db_to_json` to convert from database values to JSON objects. ([\#7849](https://github.com/matrix-org/synapse/issues/7849)) +- Convert various parts of the codebase to async/await. ([\#7851](https://github.com/matrix-org/synapse/issues/7851), [\#7860](https://github.com/matrix-org/synapse/issues/7860), [\#7868](https://github.com/matrix-org/synapse/issues/7868), [\#7871](https://github.com/matrix-org/synapse/issues/7871), [\#7873](https://github.com/matrix-org/synapse/issues/7873), [\#7874](https://github.com/matrix-org/synapse/issues/7874), [\#7884](https://github.com/matrix-org/synapse/issues/7884), [\#7912](https://github.com/matrix-org/synapse/issues/7912), [\#7935](https://github.com/matrix-org/synapse/issues/7935), [\#7939](https://github.com/matrix-org/synapse/issues/7939), [\#7942](https://github.com/matrix-org/synapse/issues/7942), [\#7944](https://github.com/matrix-org/synapse/issues/7944)) +- Add support for handling registration requests across multiple client reader workers. ([\#7853](https://github.com/matrix-org/synapse/issues/7853)) +- Small performance improvement in typing processing. ([\#7856](https://github.com/matrix-org/synapse/issues/7856)) +- The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100. ([\#7858](https://github.com/matrix-org/synapse/issues/7858)) +- Optimise queueing of inbound replication commands. ([\#7861](https://github.com/matrix-org/synapse/issues/7861)) +- Add some type annotations to `HomeServer` and `BaseHandler`. ([\#7870](https://github.com/matrix-org/synapse/issues/7870)) +- Clean up `PreserveLoggingContext`. ([\#7877](https://github.com/matrix-org/synapse/issues/7877)) +- Change "unknown room version" logging from 'error' to 'warning'. ([\#7881](https://github.com/matrix-org/synapse/issues/7881)) +- Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`. ([\#7882](https://github.com/matrix-org/synapse/issues/7882)) +- Return an empty body for OPTIONS requests. ([\#7886](https://github.com/matrix-org/synapse/issues/7886)) +- Fix typo in generated config file. Contributed by @ThiefMaster. ([\#7890](https://github.com/matrix-org/synapse/issues/7890)) +- Import ABC from `collections.abc` for Python 3.10 compatibility. ([\#7892](https://github.com/matrix-org/synapse/issues/7892)) +- Remove unused functions `time_function`, `trace_function`, `get_previous_frames` + and `get_previous_frame` from `synapse.logging.utils` module. ([\#7897](https://github.com/matrix-org/synapse/issues/7897)) +- Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI. ([\#7914](https://github.com/matrix-org/synapse/issues/7914)) +- Use Element CSS and logo in notification emails when app name is Element. ([\#7919](https://github.com/matrix-org/synapse/issues/7919)) +- Optimisation to /sync handling: skip serializing the response if the client has already disconnected. ([\#7927](https://github.com/matrix-org/synapse/issues/7927)) +- When a client disconnects, don't log it as 'Error processing request'. ([\#7928](https://github.com/matrix-org/synapse/issues/7928)) +- Add debugging to `/sync` response generation (disabled by default). ([\#7929](https://github.com/matrix-org/synapse/issues/7929)) +- Update comments that refer to Deferreds for async functions. ([\#7945](https://github.com/matrix-org/synapse/issues/7945)) +- Simplify error handling in federation handler. ([\#7950](https://github.com/matrix-org/synapse/issues/7950)) + + +Synapse 1.17.0 (2020-07-13) +=========================== + +Synapse 1.17.0 is identical to 1.17.0rc1, with the addition of the fix that was included in 1.16.1. + + +Synapse 1.16.1 (2020-07-10) +=========================== + +In some distributions of Synapse 1.16.0, we incorrectly included a database migration which added a new, unused table. This release removes the redundant table. + +Bugfixes +-------- + +- Drop table `local_rejections_stream` which was incorrectly added in Synapse 1.16.0. ([\#7816](https://github.com/matrix-org/synapse/issues/7816), [b1beb3ff5](https://github.com/matrix-org/synapse/commit/b1beb3ff5)) + + +Synapse 1.17.0rc1 (2020-07-09) +============================== + +Bugfixes +-------- + +- Fix inconsistent handling of upper and lower case in email addresses when used as identifiers for login, etc. Contributed by @dklimpel. ([\#7021](https://github.com/matrix-org/synapse/issues/7021)) +- Fix "Tried to close a non-active scope!" error messages when opentracing is enabled. ([\#7732](https://github.com/matrix-org/synapse/issues/7732)) +- Fix incorrect error message when database CTYPE was set incorrectly. ([\#7760](https://github.com/matrix-org/synapse/issues/7760)) +- Fix to not ignore `set_tweak` actions in Push Rules that have no `value`, as permitted by the specification. ([\#7766](https://github.com/matrix-org/synapse/issues/7766)) +- Fix synctl to handle empty config files correctly. Contributed by @kotovalexarian. ([\#7779](https://github.com/matrix-org/synapse/issues/7779)) +- Fixes a long standing bug in worker mode where worker information was saved in the devices table instead of the original IP address and user agent. ([\#7797](https://github.com/matrix-org/synapse/issues/7797)) +- Fix 'stuck invites' which happen when we are unable to reject a room invite received over federation. ([\#7804](https://github.com/matrix-org/synapse/issues/7804), [\#7809](https://github.com/matrix-org/synapse/issues/7809), [\#7810](https://github.com/matrix-org/synapse/issues/7810)) + + +Updates to the Docker image +--------------------------- + +- Include libwebp in the Docker file to properly handle webp image uploads. ([\#7791](https://github.com/matrix-org/synapse/issues/7791)) + + +Improved Documentation +---------------------- + +- Improve the documentation of the non-standard JSON web token login type. ([\#7776](https://github.com/matrix-org/synapse/issues/7776)) +- Update doc links for caddy. Contributed by Nicolai Søborg. ([\#7789](https://github.com/matrix-org/synapse/issues/7789)) + + +Internal Changes +---------------- + +- Refactor getting replication updates from database. ([\#7740](https://github.com/matrix-org/synapse/issues/7740)) +- Send push notifications with a high or low priority depending upon whether they may generate user-observable effects. ([\#7765](https://github.com/matrix-org/synapse/issues/7765)) +- Use symbolic names for replication stream names. ([\#7768](https://github.com/matrix-org/synapse/issues/7768)) +- Add early returns to `_check_for_soft_fail`. ([\#7769](https://github.com/matrix-org/synapse/issues/7769)) +- Fix up `synapse.handlers.federation` to pass mypy. ([\#7770](https://github.com/matrix-org/synapse/issues/7770)) +- Convert the appserver handler to async/await. ([\#7775](https://github.com/matrix-org/synapse/issues/7775)) +- Allow to use higher versions of prometheus_client <0.9.0 which are expected to introduce no breaking changes. Contributed by Oliver Kurz. ([\#7780](https://github.com/matrix-org/synapse/issues/7780)) +- Update linting scripts and codebase to be compatible with `isort` v5. ([\#7786](https://github.com/matrix-org/synapse/issues/7786)) +- Stop populating unused table `local_invites`. ([\#7793](https://github.com/matrix-org/synapse/issues/7793)) +- Ensure that strings (not bytes) are passed into JSON serialization. ([\#7799](https://github.com/matrix-org/synapse/issues/7799)) +- Switch from simplejson to the standard library json. ([\#7800](https://github.com/matrix-org/synapse/issues/7800)) +- Add `signing_key` property to `HomeServer` to save code duplication. ([\#7805](https://github.com/matrix-org/synapse/issues/7805)) +- Improve stacktraces from exceptions in background processes. ([\#7808](https://github.com/matrix-org/synapse/issues/7808)) +- Fix various spelling errors in comments and log lines. ([\#7811](https://github.com/matrix-org/synapse/issues/7811)) + + +Synapse 1.16.0 (2020-07-08) +=========================== + +No significant changes since 1.16.0rc2. + +Note that this release deprecates the `m.login.jwt` login method, renaming it +to `org.matrix.login.jwt`, as `m.login.jwt` is not part of the Matrix spec. +Otherwise the behaviour is identical. Synapse will accept both names for now, +but this may change in a future release. + +Synapse 1.16.0rc2 (2020-07-02) +============================== + +Synapse 1.16.0rc2 includes the security fixes released with Synapse 1.15.2. +Please see [below](#synapse-1152-2020-07-02) for more details. + +Improved Documentation +---------------------- + +- Update postgres image in example `docker-compose.yaml` to tag `12-alpine`. ([\#7696](https://github.com/matrix-org/synapse/issues/7696)) + + +Internal Changes +---------------- + +- Add some metrics for inbound and outbound federation latencies: `synapse_federation_server_pdu_process_time` and `synapse_event_processing_lag_by_event`. ([\#7771](https://github.com/matrix-org/synapse/issues/7771)) + + +Synapse 1.15.2 (2020-07-02) +=========================== + +Due to the two security issues highlighted below, server administrators are +encouraged to update Synapse. We are not aware of these vulnerabilities being +exploited in the wild. + +Security advisory +----------------- + +* A malicious homeserver could force Synapse to reset the state in a room to a + small subset of the correct state. This affects all Synapse deployments which + federate with untrusted servers. ([96e9afe6](https://github.com/matrix-org/synapse/commit/96e9afe62500310977dc3cbc99a8d16d3d2fa15c)) +* HTML pages served via Synapse were vulnerable to clickjacking attacks. This + predominantly affects homeservers with single-sign-on enabled, but all server + administrators are encouraged to upgrade. ([ea26e9a9](https://github.com/matrix-org/synapse/commit/ea26e9a98b0541fc886a1cb826a38352b7599dbe)) + + This was reported by [Quentin Gliech](https://sandhose.fr/). + + +Synapse 1.16.0rc1 (2020-07-01) +============================== + +Features +-------- + +- Add an option to enable encryption by default for new rooms. ([\#7639](https://github.com/matrix-org/synapse/issues/7639)) +- Add support for running multiple media repository workers. See [docs/workers.md](https://github.com/matrix-org/synapse/blob/release-v1.16.0/docs/workers.md) for instructions. ([\#7706](https://github.com/matrix-org/synapse/issues/7706)) +- Media can now be marked as safe from quarantined. ([\#7718](https://github.com/matrix-org/synapse/issues/7718)) +- Expand the configuration options for auto-join rooms. ([\#7763](https://github.com/matrix-org/synapse/issues/7763)) + + +Bugfixes +-------- + +- Remove `user_id` from the response to `GET /_matrix/client/r0/presence/{userId}/status` to match the specification. ([\#7606](https://github.com/matrix-org/synapse/issues/7606)) +- In worker mode, ensure that replicated data has not already been received. ([\#7648](https://github.com/matrix-org/synapse/issues/7648)) +- Fix intermittent exception during startup, introduced in Synapse 1.14.0. ([\#7663](https://github.com/matrix-org/synapse/issues/7663)) +- Include a user-agent for federation and well-known requests. ([\#7677](https://github.com/matrix-org/synapse/issues/7677)) +- Accept the proper field (`phone`) for the `m.id.phone` identifier type. The legacy field of `number` is still accepted as a fallback. Bug introduced in v0.20.0. ([\#7687](https://github.com/matrix-org/synapse/issues/7687)) +- Fix "Starting db txn 'get_completed_ui_auth_stages' from sentinel context" warning. The bug was introduced in 1.13.0. ([\#7688](https://github.com/matrix-org/synapse/issues/7688)) +- Compare the URI and method during user interactive authentication (instead of the URI twice). Bug introduced in 1.13.0. ([\#7689](https://github.com/matrix-org/synapse/issues/7689)) +- Fix a long standing bug where the response to the `GET room_keys/version` endpoint had the incorrect type for the `etag` field. ([\#7691](https://github.com/matrix-org/synapse/issues/7691)) +- Fix logged error during device resync in opentracing. Broke in v1.14.0. ([\#7698](https://github.com/matrix-org/synapse/issues/7698)) +- Do not break push rule evaluation when receiving an event with a non-string body. This is a long-standing bug. ([\#7701](https://github.com/matrix-org/synapse/issues/7701)) +- Fixs a long standing bug which resulted in an exception: "TypeError: argument of type 'ObservableDeferred' is not iterable". ([\#7708](https://github.com/matrix-org/synapse/issues/7708)) +- The `synapse_port_db` script no longer fails when the `ui_auth_sessions` table is non-empty. This bug has existed since v1.13.0. ([\#7711](https://github.com/matrix-org/synapse/issues/7711)) +- Synapse will now fetch media from the proper specified URL (using the r0 prefix instead of the unspecified v1). ([\#7714](https://github.com/matrix-org/synapse/issues/7714)) +- Fix the tables ignored by `synapse_port_db` to be in sync the current database schema. ([\#7717](https://github.com/matrix-org/synapse/issues/7717)) +- Fix missing `Content-Length` on HTTP responses from the metrics handler. ([\#7730](https://github.com/matrix-org/synapse/issues/7730)) +- Fix large state resolutions from stalling Synapse for seconds at a time. ([\#7735](https://github.com/matrix-org/synapse/issues/7735), [\#7746](https://github.com/matrix-org/synapse/issues/7746)) + + +Improved Documentation +---------------------- + +- Spelling correction in sample_config.yaml. ([\#7652](https://github.com/matrix-org/synapse/issues/7652)) +- Added instructions for how to use Keycloak via OpenID Connect to authenticate with Synapse. ([\#7659](https://github.com/matrix-org/synapse/issues/7659)) +- Corrected misspelling of PostgreSQL. ([\#7724](https://github.com/matrix-org/synapse/issues/7724)) + + +Deprecations and Removals +------------------------- + +- Deprecate `m.login.jwt` login method in favour of `org.matrix.login.jwt`, as `m.login.jwt` is not part of the Matrix spec. ([\#7675](https://github.com/matrix-org/synapse/issues/7675)) + + +Internal Changes +---------------- + +- Refactor getting replication updates from database. ([\#7636](https://github.com/matrix-org/synapse/issues/7636)) +- Clean-up the login fallback code. ([\#7657](https://github.com/matrix-org/synapse/issues/7657)) +- Increase the default SAML session expiry time to 15 minutes. ([\#7664](https://github.com/matrix-org/synapse/issues/7664)) +- Convert the device message and pagination handlers to async/await. ([\#7678](https://github.com/matrix-org/synapse/issues/7678)) +- Convert typing handler to async/await. ([\#7679](https://github.com/matrix-org/synapse/issues/7679)) +- Require `parameterized` package version to be at least 0.7.0. ([\#7680](https://github.com/matrix-org/synapse/issues/7680)) +- Refactor handling of `listeners` configuration settings. ([\#7681](https://github.com/matrix-org/synapse/issues/7681)) +- Replace uses of `six.iterkeys`/`iteritems`/`itervalues` with `keys()`/`items()`/`values()`. ([\#7692](https://github.com/matrix-org/synapse/issues/7692)) +- Add support for using `rust-python-jaeger-reporter` library to reduce jaeger tracing overhead. ([\#7697](https://github.com/matrix-org/synapse/issues/7697)) +- Make Tox actions work on Debian 10. ([\#7703](https://github.com/matrix-org/synapse/issues/7703)) +- Replace all remaining uses of `six` with native Python 3 equivalents. Contributed by @ilmari. ([\#7704](https://github.com/matrix-org/synapse/issues/7704)) +- Fix broken link in sample config. ([\#7712](https://github.com/matrix-org/synapse/issues/7712)) +- Speed up state res v2 across large state differences. ([\#7725](https://github.com/matrix-org/synapse/issues/7725)) +- Convert directory handler to async/await. ([\#7727](https://github.com/matrix-org/synapse/issues/7727)) +- Move `flake8` to the end of `scripts-dev/lint.sh` as it takes the longest and could cause the script to exit early. ([\#7738](https://github.com/matrix-org/synapse/issues/7738)) +- Explain the "test" conditional requirement for dependencies is not all of the modules necessary to run the unit tests. ([\#7751](https://github.com/matrix-org/synapse/issues/7751)) +- Add some metrics for inbound and outbound federation latencies: `synapse_federation_server_pdu_process_time` and `synapse_event_processing_lag_by_event`. ([\#7755](https://github.com/matrix-org/synapse/issues/7755)) + + Synapse 1.15.1 (2020-06-16) =========================== diff --git a/INSTALL.md b/INSTALL.md index ef80a26c3f..22f7b7c029 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -1,10 +1,12 @@ - [Choosing your server name](#choosing-your-server-name) +- [Picking a database engine](#picking-a-database-engine) - [Installing Synapse](#installing-synapse) - [Installing from source](#installing-from-source) - [Platform-Specific Instructions](#platform-specific-instructions) - [Prebuilt packages](#prebuilt-packages) - [Setting up Synapse](#setting-up-synapse) - [TLS certificates](#tls-certificates) + - [Client Well-Known URI](#client-well-known-uri) - [Email](#email) - [Registering a user](#registering-a-user) - [Setting up a TURN server](#setting-up-a-turn-server) @@ -27,6 +29,25 @@ that your email address is probably `user@example.com` rather than `user@email.example.com`) - but doing so may require more advanced setup: see [Setting up Federation](docs/federate.md). +# Picking a database engine + +Synapse offers two database engines: + * [PostgreSQL](https://www.postgresql.org) + * [SQLite](https://sqlite.org/) + +Almost all installations should opt to use PostgreSQL. Advantages include: + +* significant performance improvements due to the superior threading and + caching model, smarter query optimiser +* allowing the DB to be run on separate hardware + +For information on how to install and use PostgreSQL, please see +[docs/postgres.md](docs/postgres.md) + +By default Synapse uses SQLite and in doing so trades performance for convenience. +SQLite is only recommended in Synapse for testing purposes or for servers with +light workloads. + # Installing Synapse ## Installing from source @@ -234,9 +255,9 @@ for a number of platforms. There is an offical synapse image available at https://hub.docker.com/r/matrixdotorg/synapse which can be used with -the docker-compose file available at [contrib/docker](contrib/docker). Further information on -this including configuration options is available in the README on -hub.docker.com. +the docker-compose file available at [contrib/docker](contrib/docker). Further +information on this including configuration options is available in the README +on hub.docker.com. Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a Dockerfile to automate a synapse server in a single Docker image, at @@ -244,7 +265,8 @@ https://hub.docker.com/r/avhost/docker-matrix/tags/ Slavi Pantaleev has created an Ansible playbook, which installs the offical Docker image of Matrix Synapse -along with many other Matrix-related services (Postgres database, riot-web, coturn, mxisd, SSL support, etc.). +along with many other Matrix-related services (Postgres database, Element, coturn, +ma1sd, SSL support, etc.). For more details, see https://github.com/spantaleev/matrix-docker-ansible-deploy @@ -277,22 +299,27 @@ The fingerprint of the repository signing key (as shown by `gpg /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is `AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`. -#### Downstream Debian/Ubuntu packages +#### Downstream Debian packages -For `buster` and `sid`, Synapse is available in the Debian repositories and -it should be possible to install it with simply: +We do not recommend using the packages from the default Debian `buster` +repository at this time, as they are old and suffer from known security +vulnerabilities. You can install the latest version of Synapse from +[our repository](#matrixorg-packages) or from `buster-backports`. Please +see the [Debian documentation](https://backports.debian.org/Instructions/) +for information on how to use backports. + +If you are using Debian `sid` or testing, Synapse is available in the default +repositories and it should be possible to install it simply with: ``` sudo apt install matrix-synapse ``` -There is also a version of `matrix-synapse` in `stretch-backports`. Please see -the [Debian documentation on -backports](https://backports.debian.org/Instructions/) for information on how -to use them. +#### Downstream Ubuntu packages -We do not recommend using the packages in downstream Ubuntu at this time, as -they are old and suffer from known security vulnerabilities. +We do not recommend using the packages in the default Ubuntu repository +at this time, as they are old and suffer from known security vulnerabilities. +The latest version of Synapse can be installed from [our repository](#matrixorg-packages). ### Fedora @@ -405,13 +432,11 @@ so, you will need to edit `homeserver.yaml`, as follows: ``` * You will also need to uncomment the `tls_certificate_path` and - `tls_private_key_path` lines under the `TLS` section. You can either - point these settings at an existing certificate and key, or you can - enable Synapse's built-in ACME (Let's Encrypt) support. Instructions - for having Synapse automatically provision and renew federation - certificates through ACME can be found at [ACME.md](docs/ACME.md). - Note that, as pointed out in that document, this feature will not - work with installs set up after November 2019. + `tls_private_key_path` lines under the `TLS` section. You will need to manage + provisioning of these certificates yourself — Synapse had built-in ACME + support, but the ACMEv1 protocol Synapse implements is deprecated, not + allowed by LetsEncrypt for new sites, and will break for existing sites in + late 2020. See [ACME.md](docs/ACME.md). If you are using your own certificate, be sure to use a `.pem` file that includes the full certificate chain including any intermediate certificates @@ -421,6 +446,60 @@ so, you will need to edit `homeserver.yaml`, as follows: For a more detailed guide to configuring your server for federation, see [federate.md](docs/federate.md). +## Client Well-Known URI + +Setting up the client Well-Known URI is optional but if you set it up, it will +allow users to enter their full username (e.g. `@user:<server_name>`) into clients +which support well-known lookup to automatically configure the homeserver and +identity server URLs. This is useful so that users don't have to memorize or think +about the actual homeserver URL you are using. + +The URL `https://<server_name>/.well-known/matrix/client` should return JSON in +the following format. + +``` +{ + "m.homeserver": { + "base_url": "https://<matrix.example.com>" + } +} +``` + +It can optionally contain identity server information as well. + +``` +{ + "m.homeserver": { + "base_url": "https://<matrix.example.com>" + }, + "m.identity_server": { + "base_url": "https://<identity.example.com>" + } +} +``` + +To work in browser based clients, the file must be served with the appropriate +Cross-Origin Resource Sharing (CORS) headers. A recommended value would be +`Access-Control-Allow-Origin: *` which would allow all browser based clients to +view it. + +In nginx this would be something like: +``` +location /.well-known/matrix/client { + return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}'; + add_header Content-Type application/json; + add_header Access-Control-Allow-Origin *; +} +``` + +You should also ensure the `public_baseurl` option in `homeserver.yaml` is set +correctly. `public_baseurl` should be set to the URL that clients will use to +connect to your server. This is the same URL you put for the `m.homeserver` +`base_url` above. + +``` +public_baseurl: "https://<matrix.example.com>" +``` ## Email @@ -439,7 +518,7 @@ email will be disabled. ## Registering a user -The easiest way to create a new user is to do so from a client like [Riot](https://riot.im). +The easiest way to create a new user is to do so from a client like [Element](https://element.io/). Alternatively you can do so from the command line if you have installed via pip. diff --git a/README.rst b/README.rst index 31d375d19b..4a189c8bc4 100644 --- a/README.rst +++ b/README.rst @@ -45,7 +45,7 @@ which handle: - Eventually-consistent cryptographically secure synchronisation of room state across a global open network of federated servers and services - Sending and receiving extensible messages in a room with (optional) - end-to-end encryption[1] + end-to-end encryption - Inviting, joining, leaving, kicking, banning room members - Managing user accounts (registration, login, logout) - Using 3rd Party IDs (3PIDs) such as email addresses, phone numbers, @@ -82,9 +82,6 @@ at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the Thanks for using Matrix! -[1] End-to-end encryption is currently in beta: `blog post <https://matrix.org/blog/2016/11/21/matrixs-olm-end-to-end-encryption-security-assessment-released-and-implemented-cross-platform-on-riot-at-last>`_. - - Support ======= @@ -115,12 +112,11 @@ Unless you are running a test instance of Synapse on your local machine, in general, you will need to enable TLS support before you can successfully connect from a client: see `<INSTALL.md#tls-certificates>`_. -An easy way to get started is to login or register via Riot at -https://riot.im/app/#/login or https://riot.im/app/#/register respectively. +An easy way to get started is to login or register via Element at +https://app.element.io/#/login or https://app.element.io/#/register respectively. You will need to change the server you are logging into from ``matrix.org`` and instead specify a Homeserver URL of ``https://<server_name>:8448`` (or just ``https://<server_name>`` if you are using a reverse proxy). -(Leave the identity server as the default - see `Identity servers`_.) If you prefer to use another client, refer to our `client breakdown <https://matrix.org/docs/projects/clients-matrix>`_. @@ -137,7 +133,7 @@ it, specify ``enable_registration: true`` in ``homeserver.yaml``. (It is then recommended to also set up CAPTCHA - see `<docs/CAPTCHA_SETUP.md>`_.) Once ``enable_registration`` is set to ``true``, it is possible to register a -user via `riot.im <https://riot.im/app/#/register>`_ or other Matrix clients. +user via a Matrix client. Your new user name will be formed partly from the ``server_name``, and partly from a localpart you specify when you create the account. Your name will take @@ -183,30 +179,6 @@ versions of synapse. .. _UPGRADE.rst: UPGRADE.rst - -Using PostgreSQL -================ - -Synapse offers two database engines: - * `SQLite <https://sqlite.org/>`_ - * `PostgreSQL <https://www.postgresql.org>`_ - -By default Synapse uses SQLite in and doing so trades performance for convenience. -SQLite is only recommended in Synapse for testing purposes or for servers with -light workloads. - -Almost all installations should opt to use PostreSQL. Advantages include: - -* significant performance improvements due to the superior threading and - caching model, smarter query optimiser -* allowing the DB to be run on separate hardware -* allowing basic active/backup high-availability with a "hot spare" synapse - pointing at the same DB master, as well as enabling DB replication in - synapse itself. - -For information on how to install and use PostgreSQL, please see -`docs/postgres.md <docs/postgres.md>`_. - .. _reverse-proxy: Using a reverse proxy with Synapse @@ -215,7 +187,7 @@ Using a reverse proxy with Synapse It is recommended to put a reverse proxy such as `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_, `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_, -`Caddy <https://caddyserver.com/docs/proxy>`_ or +`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or `HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of doing so is that it means that you can expose the default https port (443) to Matrix clients without needing to run Synapse with root privileges. @@ -255,10 +227,9 @@ email address. Password reset ============== -If a user has registered an email address to their account using an identity -server, they can request a password-reset token via clients such as Riot. - -A manual password reset can be done via direct database access as follows. +Users can reset their password through their client. Alternatively, a server admin +can reset a users password using the `admin API <docs/admin_api/user_admin_api.rst#reset-password>`_ +or by directly editing the database as shown below. First calculate the hash of the new password:: diff --git a/UPGRADE.rst b/UPGRADE.rst index 3b5627e852..6492fa011f 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,24 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.18.0 +==================== + +Docker `-py3` suffix will be removed in future versions +------------------------------------------------------- + +From 10th August 2020, we will no longer publish Docker images with the `-py3` tag suffix. The images tagged with the `-py3` suffix have been identical to the non-suffixed tags since release 0.99.0, and the suffix is obsolete. + +On 10th August, we will remove the `latest-py3` tag. Existing per-release tags (such as `v1.18.0-py3`) will not be removed, but no new `-py3` tags will be added. + +Scripts relying on the `-py3` suffix will need to be updated. + +Redis replication is now recommended in lieu of TCP replication +--------------------------------------------------------------- + +When setting up worker processes, we now recommend the use of a Redis server for replication. **The old direct TCP connection method is deprecated and will be removed in a future release.** +See `docs/workers.md <docs/workers.md>`_ for more details. + Upgrading to v1.14.0 ==================== diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 48da410d94..dfc1d294dc 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -17,9 +17,6 @@ """ Starts a synapse client console. """ from __future__ import print_function -from twisted.internet import reactor, defer, threads -from http import TwistedHttpClient - import argparse import cmd import getpass @@ -28,12 +25,14 @@ import shlex import sys import time import urllib -import urlparse +from http import TwistedHttpClient -import nacl.signing import nacl.encoding +import nacl.signing +import urlparse +from signedjson.sign import SignatureVerifyException, verify_signed_json -from signedjson.sign import verify_signed_json, SignatureVerifyException +from twisted.internet import defer, reactor, threads CONFIG_JSON = "cmdclient_config.json" @@ -493,7 +492,7 @@ class SynapseCmd(cmd.Cmd): "list messages <roomid> from=END&to=START&limit=3" """ args = self._parse(line, ["type", "roomid", "qp"]) - if not "type" in args or not "roomid" in args: + if "type" not in args or "roomid" not in args: print("Must specify type and room ID.") return if args["type"] not in ["members", "messages"]: @@ -508,7 +507,7 @@ class SynapseCmd(cmd.Cmd): try: key_value = key_value_str.split("=") qp[key_value[0]] = key_value[1] - except: + except Exception: print("Bad query param: %s" % key_value) return @@ -585,7 +584,7 @@ class SynapseCmd(cmd.Cmd): parsed_url = urlparse.urlparse(args["path"]) qp.update(urlparse.parse_qs(parsed_url.query)) args["path"] = parsed_url.path - except: + except Exception: pass reactor.callFromThread( @@ -610,13 +609,15 @@ class SynapseCmd(cmd.Cmd): @defer.inlineCallbacks def _do_event_stream(self, timeout): - res = yield self.http_client.get_json( - self._url() + "/events", - { - "access_token": self._tok(), - "timeout": str(timeout), - "from": self.event_stream_token, - }, + res = yield defer.ensureDeferred( + self.http_client.get_json( + self._url() + "/events", + { + "access_token": self._tok(), + "timeout": str(timeout), + "from": self.event_stream_token, + }, + ) ) print(json.dumps(res, indent=4)) @@ -772,10 +773,10 @@ def main(server_url, identity_server_url, username, token, config_path): syn_cmd.config = json.load(config) try: http_client.verbose = "on" == syn_cmd.config["verbose"] - except: + except Exception: pass print("Loaded config from %s" % config_path) - except: + except Exception: pass # Twisted-specific: Runs the command processor in Twisted's event loop diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 0e101d2be5..cd3260b27d 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -14,17 +14,17 @@ # limitations under the License. from __future__ import print_function -from twisted.web.client import Agent, readBody -from twisted.web.http_headers import Headers -from twisted.internet import defer, reactor - -from pprint import pformat import json import urllib +from pprint import pformat + +from twisted.internet import defer, reactor +from twisted.web.client import Agent, readBody +from twisted.web.http_headers import Headers -class HttpClient(object): +class HttpClient: """ Interface for talking json over http """ @@ -169,7 +169,7 @@ class TwistedHttpClient(HttpClient): return d -class _RawProducer(object): +class _RawProducer: def __init__(self, data): self.data = data self.body = data @@ -186,7 +186,7 @@ class _RawProducer(object): pass -class _JsonProducer(object): +class _JsonProducer: """ Used by the twisted http client to create the HTTP body from json """ diff --git a/contrib/docker/docker-compose.yml b/contrib/docker/docker-compose.yml index 17354b6610..d1ecd453db 100644 --- a/contrib/docker/docker-compose.yml +++ b/contrib/docker/docker-compose.yml @@ -50,7 +50,7 @@ services: - traefik.http.routers.https-synapse.tls.certResolver=le-ssl db: - image: docker.io/postgres:10-alpine + image: docker.io/postgres:12-alpine # Change that password, of course! environment: - POSTGRES_USER=synapse diff --git a/contrib/experiments/cursesio.py b/contrib/experiments/cursesio.py index ffefe3bb39..15a22c3a0e 100644 --- a/contrib/experiments/cursesio.py +++ b/contrib/experiments/cursesio.py @@ -141,7 +141,7 @@ class CursesStdIO: curses.endwin() -class Callback(object): +class Callback: def __init__(self, stdio): self.stdio = stdio diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 3bbbcfa1b4..d4c35ff2fc 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -28,27 +28,24 @@ Currently assumes the local address is localhost:<port> """ -from synapse.federation import ReplicationHandler - -from synapse.federation.units import Pdu - -from synapse.util import origin_from_ucid - -from synapse.app.homeserver import SynapseHomeServer - -# from synapse.logging.utils import log_function - -from twisted.internet import reactor, defer -from twisted.python import log - import argparse +import curses.wrapper import json import logging import os import re import cursesio -import curses.wrapper + +from twisted.internet import defer, reactor +from twisted.python import log + +from synapse.app.homeserver import SynapseHomeServer +from synapse.federation import ReplicationHandler +from synapse.federation.units import Pdu +from synapse.util import origin_from_ucid + +# from synapse.logging.utils import log_function logger = logging.getLogger("example") @@ -58,7 +55,7 @@ def excpetion_errback(failure): logging.exception(failure) -class InputOutput(object): +class InputOutput: """ This is responsible for basic I/O so that a user can interact with the example app. """ @@ -75,7 +72,7 @@ class InputOutput(object): """ try: - m = re.match("^join (\S+)$", line) + m = re.match(r"^join (\S+)$", line) if m: # The `sender` wants to join a room. (room_name,) = m.groups() @@ -84,7 +81,7 @@ class InputOutput(object): # self.print_line("OK.") return - m = re.match("^invite (\S+) (\S+)$", line) + m = re.match(r"^invite (\S+) (\S+)$", line) if m: # `sender` wants to invite someone to a room room_name, invitee = m.groups() @@ -93,7 +90,7 @@ class InputOutput(object): # self.print_line("OK.") return - m = re.match("^send (\S+) (.*)$", line) + m = re.match(r"^send (\S+) (.*)$", line) if m: # `sender` wants to message a room room_name, body = m.groups() @@ -102,7 +99,7 @@ class InputOutput(object): # self.print_line("OK.") return - m = re.match("^backfill (\S+)$", line) + m = re.match(r"^backfill (\S+)$", line) if m: # we want to backfill a room (room_name,) = m.groups() @@ -135,7 +132,7 @@ class IOLoggerHandler(logging.Handler): self.io.print_log(msg) -class Room(object): +class Room: """ Used to store (in memory) the current membership state of a room, and which home servers we should send PDUs associated with the room to. """ @@ -201,16 +198,6 @@ class HomeServer(ReplicationHandler): % (pdu.context, pdu.pdu_type, json.dumps(pdu.content)) ) - # def on_state_change(self, pdu): - ##self.output.print_line("#%s (state) %s *** %s" % - ##(pdu.context, pdu.state_key, pdu.pdu_type) - ##) - - # if "joinee" in pdu.content: - # self._on_join(pdu.context, pdu.content["joinee"]) - # elif "invitee" in pdu.content: - # self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"]) - def _on_message(self, pdu): """ We received a message """ @@ -314,7 +301,7 @@ class HomeServer(ReplicationHandler): return self.replication_layer.backfill(dest, room_name, limit) def _get_room_remote_servers(self, room_name): - return [i for i in self.joined_rooms.setdefault(room_name).servers] + return list(self.joined_rooms.setdefault(room_name).servers) def _get_or_create_room(self, room_name): return self.joined_rooms.setdefault(room_name, Room(room_name)) @@ -334,7 +321,7 @@ def main(stdscr): user = args.user server_name = origin_from_ucid(user) - ## Set up logging ## + # Set up logging root_logger = logging.getLogger() @@ -354,7 +341,7 @@ def main(stdscr): observer = log.PythonLoggingObserver() observer.start() - ## Set up synapse server + # Set up synapse server curses_stdio = cursesio.CursesStdIO(stdscr) input_output = InputOutput(curses_stdio, user) @@ -368,16 +355,16 @@ def main(stdscr): input_output.set_home_server(hs) - ## Add input_output logger + # Add input_output logger io_logger = IOLoggerHandler(input_output) io_logger.setFormatter(formatter) root_logger.addHandler(io_logger) - ## Start! ## + # Start! try: port = int(server_name.split(":")[1]) - except: + except Exception: port = 12345 app_hs.get_http_server().start_listening(port) diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json index 30a8681f5a..539569b5b1 100644 --- a/contrib/grafana/synapse.json +++ b/contrib/grafana/synapse.json @@ -1,7 +1,44 @@ { + "__inputs": [ + { + "name": "DS_PROMETHEUS", + "label": "Prometheus", + "description": "", + "type": "datasource", + "pluginId": "prometheus", + "pluginName": "Prometheus" + } + ], + "__requires": [ + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "6.7.4" + }, + { + "type": "panel", + "id": "graph", + "name": "Graph", + "version": "" + }, + { + "type": "panel", + "id": "heatmap", + "name": "Heatmap", + "version": "" + }, + { + "type": "datasource", + "id": "prometheus", + "name": "Prometheus", + "version": "1.0.0" + } + ], "annotations": { "list": [ { + "$$hashKey": "object:76", "builtIn": 1, "datasource": "$datasource", "enable": false, @@ -17,8 +54,8 @@ "editable": true, "gnetId": null, "graphTooltip": 0, - "id": 1, - "iteration": 1591098104645, + "id": null, + "iteration": 1594646317221, "links": [ { "asDropdown": true, @@ -34,7 +71,7 @@ "panels": [ { "collapsed": false, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -269,7 +306,6 @@ "show": false }, "links": [], - "options": {}, "reverseYBuckets": false, "targets": [ { @@ -559,7 +595,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -1423,7 +1459,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -1795,7 +1831,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -2531,7 +2567,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -2823,7 +2859,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -2844,7 +2880,7 @@ "h": 9, "w": 12, "x": 0, - "y": 33 + "y": 6 }, "hiddenSeries": false, "id": 79, @@ -2940,7 +2976,7 @@ "h": 9, "w": 12, "x": 12, - "y": 33 + "y": 6 }, "hiddenSeries": false, "id": 83, @@ -3038,7 +3074,7 @@ "h": 9, "w": 12, "x": 0, - "y": 42 + "y": 15 }, "hiddenSeries": false, "id": 109, @@ -3137,7 +3173,7 @@ "h": 9, "w": 12, "x": 12, - "y": 42 + "y": 15 }, "hiddenSeries": false, "id": 111, @@ -3223,14 +3259,14 @@ "dashLength": 10, "dashes": false, "datasource": "$datasource", - "description": "", + "description": "Number of events queued up on the master process for processing by the federation sender", "fill": 1, "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, - "y": 51 + "y": 24 }, "hiddenSeries": false, "id": 140, @@ -3354,6 +3390,103 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "description": "The number of events in the in-memory queues ", + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "hiddenSeries": false, + "id": 142, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "dataLinks": [] + }, + "percentage": false, + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "synapse_federation_transaction_queue_pending_pdus{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", + "interval": "", + "legendFormat": "pending PDUs {{job}}-{{index}}", + "refId": "A" + }, + { + "expr": "synapse_federation_transaction_queue_pending_edus{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", + "interval": "", + "legendFormat": "pending EDUs {{job}}-{{index}}", + "refId": "B" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "In-memory federation transmission queues", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "$$hashKey": "object:317", + "format": "short", + "label": "events", + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "$$hashKey": "object:318", + "format": "short", + "label": "", + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "title": "Federation", @@ -3361,7 +3494,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -3567,7 +3700,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -3588,7 +3721,7 @@ "h": 7, "w": 12, "x": 0, - "y": 52 + "y": 79 }, "hiddenSeries": false, "id": 48, @@ -3682,7 +3815,7 @@ "h": 7, "w": 12, "x": 12, - "y": 52 + "y": 79 }, "hiddenSeries": false, "id": 104, @@ -3802,7 +3935,7 @@ "h": 7, "w": 12, "x": 0, - "y": 59 + "y": 86 }, "hiddenSeries": false, "id": 10, @@ -3898,7 +4031,7 @@ "h": 7, "w": 12, "x": 12, - "y": 59 + "y": 86 }, "hiddenSeries": false, "id": 11, @@ -3987,7 +4120,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -4011,7 +4144,7 @@ "h": 13, "w": 12, "x": 0, - "y": 67 + "y": 80 }, "hiddenSeries": false, "id": 12, @@ -4106,7 +4239,7 @@ "h": 13, "w": 12, "x": 12, - "y": 67 + "y": 80 }, "hiddenSeries": false, "id": 26, @@ -4201,7 +4334,7 @@ "h": 13, "w": 12, "x": 0, - "y": 80 + "y": 93 }, "hiddenSeries": false, "id": 13, @@ -4297,7 +4430,7 @@ "h": 13, "w": 12, "x": 12, - "y": 80 + "y": 93 }, "hiddenSeries": false, "id": 27, @@ -4392,7 +4525,7 @@ "h": 13, "w": 12, "x": 0, - "y": 93 + "y": 106 }, "hiddenSeries": false, "id": 28, @@ -4486,7 +4619,7 @@ "h": 13, "w": 12, "x": 12, - "y": 93 + "y": 106 }, "hiddenSeries": false, "id": 25, @@ -4572,7 +4705,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -5062,7 +5195,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -5083,7 +5216,7 @@ "h": 9, "w": 12, "x": 0, - "y": 66 + "y": 121 }, "hiddenSeries": false, "id": 91, @@ -5179,7 +5312,7 @@ "h": 9, "w": 12, "x": 12, - "y": 66 + "y": 121 }, "hiddenSeries": false, "id": 21, @@ -5271,7 +5404,7 @@ "h": 9, "w": 12, "x": 0, - "y": 75 + "y": 130 }, "hiddenSeries": false, "id": 89, @@ -5369,7 +5502,7 @@ "h": 9, "w": 12, "x": 12, - "y": 75 + "y": 130 }, "hiddenSeries": false, "id": 93, @@ -5459,7 +5592,7 @@ "h": 9, "w": 12, "x": 0, - "y": 84 + "y": 139 }, "hiddenSeries": false, "id": 95, @@ -5552,12 +5685,12 @@ "mode": "spectrum" }, "dataFormat": "tsbuckets", - "datasource": "Prometheus", + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 9, "w": 12, "x": 12, - "y": 84 + "y": 139 }, "heatmap": {}, "hideZeroBuckets": true, @@ -5567,7 +5700,6 @@ "show": true }, "links": [], - "options": {}, "reverseYBuckets": false, "targets": [ { @@ -5609,7 +5741,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -5630,7 +5762,7 @@ "h": 7, "w": 12, "x": 0, - "y": 39 + "y": 66 }, "hiddenSeries": false, "id": 2, @@ -5754,7 +5886,7 @@ "h": 7, "w": 12, "x": 12, - "y": 39 + "y": 66 }, "hiddenSeries": false, "id": 41, @@ -5847,7 +5979,7 @@ "h": 7, "w": 12, "x": 0, - "y": 46 + "y": 73 }, "hiddenSeries": false, "id": 42, @@ -5939,7 +6071,7 @@ "h": 7, "w": 12, "x": 12, - "y": 46 + "y": 73 }, "hiddenSeries": false, "id": 43, @@ -6031,7 +6163,7 @@ "h": 7, "w": 12, "x": 0, - "y": 53 + "y": 80 }, "hiddenSeries": false, "id": 113, @@ -6129,7 +6261,7 @@ "h": 7, "w": 12, "x": 12, - "y": 53 + "y": 80 }, "hiddenSeries": false, "id": 115, @@ -6215,7 +6347,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -6236,7 +6368,7 @@ "h": 9, "w": 12, "x": 0, - "y": 58 + "y": 40 }, "hiddenSeries": false, "id": 67, @@ -6267,7 +6399,7 @@ "steppedLine": false, "targets": [ { - "expr": " synapse_event_persisted_position{instance=\"$instance\",job=\"synapse\"} - ignoring(index, job, name) group_right() synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", + "expr": "max(synapse_event_persisted_position{instance=\"$instance\"}) - ignoring(instance,index, job, name) group_right() synapse_event_processing_positions{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", "format": "time_series", "interval": "", "intervalFactor": 1, @@ -6328,7 +6460,7 @@ "h": 9, "w": 12, "x": 12, - "y": 58 + "y": 40 }, "hiddenSeries": false, "id": 71, @@ -6362,6 +6494,7 @@ "expr": "time()*1000-synapse_event_processing_last_ts{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", "format": "time_series", "hide": false, + "interval": "", "intervalFactor": 1, "legendFormat": "{{job}}-{{index}} {{name}}", "refId": "B" @@ -6420,7 +6553,7 @@ "h": 9, "w": 12, "x": 0, - "y": 67 + "y": 49 }, "hiddenSeries": false, "id": 121, @@ -6509,7 +6642,7 @@ }, { "collapsed": true, - "datasource": null, + "datasource": "${DS_PROMETHEUS}", "gridPos": { "h": 1, "w": 24, @@ -6539,7 +6672,7 @@ "h": 8, "w": 12, "x": 0, - "y": 41 + "y": 86 }, "heatmap": {}, "hideZeroBuckets": true, @@ -6549,7 +6682,6 @@ "show": true }, "links": [], - "options": {}, "reverseYBuckets": false, "targets": [ { @@ -6599,7 +6731,7 @@ "h": 8, "w": 12, "x": 12, - "y": 41 + "y": 86 }, "hiddenSeries": false, "id": 124, @@ -6700,7 +6832,7 @@ "h": 8, "w": 12, "x": 0, - "y": 49 + "y": 94 }, "heatmap": {}, "hideZeroBuckets": true, @@ -6710,7 +6842,6 @@ "show": true }, "links": [], - "options": {}, "reverseYBuckets": false, "targets": [ { @@ -6760,7 +6891,7 @@ "h": 8, "w": 12, "x": 12, - "y": 49 + "y": 94 }, "hiddenSeries": false, "id": 128, @@ -6879,7 +7010,7 @@ "h": 8, "w": 12, "x": 0, - "y": 57 + "y": 102 }, "heatmap": {}, "hideZeroBuckets": true, @@ -6889,7 +7020,6 @@ "show": true }, "links": [], - "options": {}, "reverseYBuckets": false, "targets": [ { @@ -6939,7 +7069,7 @@ "h": 8, "w": 12, "x": 12, - "y": 57 + "y": 102 }, "hiddenSeries": false, "id": 130, @@ -7058,7 +7188,7 @@ "h": 8, "w": 12, "x": 0, - "y": 65 + "y": 110 }, "heatmap": {}, "hideZeroBuckets": true, @@ -7068,12 +7198,12 @@ "show": true }, "links": [], - "options": {}, "reverseYBuckets": false, "targets": [ { - "expr": "rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0)", + "expr": "rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", "format": "heatmap", + "interval": "", "intervalFactor": 1, "legendFormat": "{{le}}", "refId": "A" @@ -7118,7 +7248,7 @@ "h": 8, "w": 12, "x": 12, - "y": 65 + "y": 110 }, "hiddenSeries": false, "id": 132, @@ -7149,29 +7279,33 @@ "steppedLine": false, "targets": [ { - "expr": "histogram_quantile(0.5, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0)) ", + "expr": "histogram_quantile(0.5, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))", "format": "time_series", + "interval": "", "intervalFactor": 1, "legendFormat": "50%", "refId": "A" }, { - "expr": "histogram_quantile(0.75, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0))", + "expr": "histogram_quantile(0.75, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))", "format": "time_series", + "interval": "", "intervalFactor": 1, "legendFormat": "75%", "refId": "B" }, { - "expr": "histogram_quantile(0.90, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0))", + "expr": "histogram_quantile(0.90, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))", "format": "time_series", + "interval": "", "intervalFactor": 1, "legendFormat": "90%", "refId": "C" }, { - "expr": "histogram_quantile(0.99, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\"}[$bucket_size]) and on (index, instance, job) (synapse_storage_events_persisted_events > 0))", + "expr": "histogram_quantile(0.99, rate(synapse_state_number_state_groups_in_resolution_bucket{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size]))", "format": "time_series", + "interval": "", "intervalFactor": 1, "legendFormat": "99%", "refId": "D" @@ -7181,7 +7315,7 @@ "timeFrom": null, "timeRegions": [], "timeShift": null, - "title": "Number of state resolution performed, by number of state groups involved (quantiles)", + "title": "Number of state resolutions performed, by number of state groups involved (quantiles)", "tooltip": { "shared": true, "sort": 0, @@ -7233,6 +7367,7 @@ "list": [ { "current": { + "selected": false, "text": "Prometheus", "value": "Prometheus" }, @@ -7309,14 +7444,12 @@ }, { "allValue": null, - "current": { - "text": "matrix.org", - "value": "matrix.org" - }, + "current": {}, "datasource": "$datasource", "definition": "", "hide": 0, "includeAll": false, + "index": -1, "label": null, "multi": false, "name": "instance", @@ -7335,17 +7468,13 @@ { "allFormat": "regex wildcard", "allValue": "", - "current": { - "text": "synapse", - "value": [ - "synapse" - ] - }, + "current": {}, "datasource": "$datasource", "definition": "", "hide": 0, "hideLabel": false, "includeAll": true, + "index": -1, "label": "Job", "multi": true, "multiFormat": "regex values", @@ -7366,16 +7495,13 @@ { "allFormat": "regex wildcard", "allValue": ".*", - "current": { - "selected": false, - "text": "All", - "value": "$__all" - }, + "current": {}, "datasource": "$datasource", "definition": "", "hide": 0, "hideLabel": false, "includeAll": true, + "index": -1, "label": "", "multi": true, "multiFormat": "regex values", @@ -7428,5 +7554,8 @@ "timezone": "", "title": "Synapse", "uid": "000000012", - "version": 29 + "variables": { + "list": [] + }, + "version": 32 } \ No newline at end of file diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py index 92736480eb..de33fac1c7 100644 --- a/contrib/graph/graph.py +++ b/contrib/graph/graph.py @@ -1,5 +1,13 @@ from __future__ import print_function +import argparse +import cgi +import datetime +import json + +import pydot +import urllib2 + # Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,15 +23,6 @@ from __future__ import print_function # limitations under the License. -import sqlite3 -import pydot -import cgi -import json -import datetime -import argparse -import urllib2 - - def make_name(pdu_id, origin): return "%s@%s" % (pdu_id, origin) @@ -33,7 +32,7 @@ def make_graph(pdus, room, filename_prefix): node_map = {} origins = set() - colors = set(("red", "green", "blue", "yellow", "purple")) + colors = {"red", "green", "blue", "yellow", "purple"} for pdu in pdus: origins.add(pdu.get("origin")) @@ -49,7 +48,7 @@ def make_graph(pdus, room, filename_prefix): try: c = colors.pop() color_map[o] = c - except: + except Exception: print("Run out of colours!") color_map[o] = "black" diff --git a/contrib/graph/graph2.py b/contrib/graph/graph2.py index 4619f0e3c1..0980231e4a 100644 --- a/contrib/graph/graph2.py +++ b/contrib/graph/graph2.py @@ -13,12 +13,13 @@ # limitations under the License. -import sqlite3 -import pydot +import argparse import cgi -import json import datetime -import argparse +import json +import sqlite3 + +import pydot from synapse.events import FrozenEvent from synapse.util.frozenutils import unfreeze @@ -98,7 +99,7 @@ def make_graph(db_name, room_id, file_prefix, limit): for prev_id, _ in event.prev_events: try: end_node = node_map[prev_id] - except: + except Exception: end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,)) node_map[prev_id] = end_node diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py index 7f9e5374a6..91db98e7ef 100644 --- a/contrib/graph/graph3.py +++ b/contrib/graph/graph3.py @@ -1,5 +1,15 @@ from __future__ import print_function +import argparse +import cgi +import datetime + +import pydot +import simplejson as json + +from synapse.events import FrozenEvent +from synapse.util.frozenutils import unfreeze + # Copyright 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,18 +25,6 @@ from __future__ import print_function # limitations under the License. -import pydot -import cgi -import simplejson as json -import datetime -import argparse - -from synapse.events import FrozenEvent -from synapse.util.frozenutils import unfreeze - -from six import string_types - - def make_graph(file_name, room_id, file_prefix, limit): print("Reading lines") with open(file_name) as f: @@ -62,7 +60,7 @@ def make_graph(file_name, room_id, file_prefix, limit): for key, value in unfreeze(event.get_dict()["content"]).items(): if value is None: value = "<null>" - elif isinstance(value, string_types): + elif isinstance(value, str): pass else: value = json.dumps(value) @@ -108,7 +106,7 @@ def make_graph(file_name, room_id, file_prefix, limit): for prev_id, _ in event.prev_events: try: end_node = node_map[prev_id] - except: + except Exception: end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,)) node_map[prev_id] = end_node diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py index 67fb2cd1a7..69aa74bd34 100644 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ b/contrib/jitsimeetbridge/jitsimeetbridge.py @@ -12,15 +12,15 @@ npm install jquery jsdom """ from __future__ import print_function -import gevent -import grequests -from BeautifulSoup import BeautifulSoup import json -import urllib import subprocess import time -# ACCESS_TOKEN="" # +import gevent +import grequests +from BeautifulSoup import BeautifulSoup + +ACCESS_TOKEN = "" MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/" MYUSERNAME = "@davetest:matrix.org" diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py index f57e6e7d25..372dbd9e4f 100755 --- a/contrib/scripts/kick_users.py +++ b/contrib/scripts/kick_users.py @@ -1,10 +1,12 @@ #!/usr/bin/env python from __future__ import print_function -from argparse import ArgumentParser + import json -import requests import sys import urllib +from argparse import ArgumentParser + +import requests try: raw_input diff --git a/debian/changelog b/debian/changelog index 3e83e9be9a..bde3b636ee 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,61 @@ +matrix-synapse-py3 (1.19.0ubuntu1) UNRELEASED; urgency=medium + + * Use Type=notify in systemd service + + -- Dexter Chua <dec41@srcf.net> Wed, 26 Aug 2020 12:41:36 +0000 + +matrix-synapse-py3 (1.19.1) stable; urgency=medium + + * New synapse release 1.19.1. + + -- Synapse Packaging team <packages@matrix.org> Thu, 27 Aug 2020 10:50:19 +0100 + +matrix-synapse-py3 (1.19.0) stable; urgency=medium + + [ Synapse Packaging team ] + * New synapse release 1.19.0. + + [ Aaron Raimist ] + * Fix outdated documentation for SYNAPSE_CACHE_FACTOR + + -- Synapse Packaging team <packages@matrix.org> Mon, 17 Aug 2020 14:06:42 +0100 + +matrix-synapse-py3 (1.18.0) stable; urgency=medium + + * New synapse release 1.18.0. + + -- Synapse Packaging team <packages@matrix.org> Thu, 30 Jul 2020 10:55:53 +0100 + +matrix-synapse-py3 (1.17.0) stable; urgency=medium + + * New synapse release 1.17.0. + + -- Synapse Packaging team <packages@matrix.org> Mon, 13 Jul 2020 10:20:31 +0100 + +matrix-synapse-py3 (1.16.1) stable; urgency=medium + + * New synapse release 1.16.1. + + -- Synapse Packaging team <packages@matrix.org> Fri, 10 Jul 2020 12:09:24 +0100 + +matrix-synapse-py3 (1.17.0rc1) stable; urgency=medium + + * New synapse release 1.17.0rc1. + + -- Synapse Packaging team <packages@matrix.org> Thu, 09 Jul 2020 16:53:12 +0100 + +matrix-synapse-py3 (1.16.0) stable; urgency=medium + + * New synapse release 1.16.0. + + -- Synapse Packaging team <packages@matrix.org> Wed, 08 Jul 2020 11:03:48 +0100 + +matrix-synapse-py3 (1.15.2) stable; urgency=medium + + * New synapse release 1.15.2. + + -- Synapse Packaging team <packages@matrix.org> Thu, 02 Jul 2020 10:34:00 -0400 + matrix-synapse-py3 (1.15.1) stable; urgency=medium * New synapse release 1.15.1. diff --git a/debian/matrix-synapse.default b/debian/matrix-synapse.default index 65dc2f33d8..f402d73bbf 100644 --- a/debian/matrix-synapse.default +++ b/debian/matrix-synapse.default @@ -1,2 +1,2 @@ # Specify environment variables used when running Synapse -# SYNAPSE_CACHE_FACTOR=1 (default) +# SYNAPSE_CACHE_FACTOR=0.5 (default) diff --git a/debian/matrix-synapse.service b/debian/matrix-synapse.service index b0a8d72e6d..553babf549 100644 --- a/debian/matrix-synapse.service +++ b/debian/matrix-synapse.service @@ -2,7 +2,7 @@ Description=Synapse Matrix homeserver [Service] -Type=simple +Type=notify User=matrix-synapse WorkingDirectory=/var/lib/matrix-synapse EnvironmentFile=/etc/default/matrix-synapse diff --git a/debian/synctl.ronn b/debian/synctl.ronn index a73c832f62..1bad6094f3 100644 --- a/debian/synctl.ronn +++ b/debian/synctl.ronn @@ -46,19 +46,20 @@ Configuration file may be generated as follows: ## ENVIRONMENT * `SYNAPSE_CACHE_FACTOR`: - Synapse's architecture is quite RAM hungry currently - a lot of - recent room data and metadata is deliberately cached in RAM in - order to speed up common requests. This will be improved in - future, but for now the easiest way to either reduce the RAM usage - (at the risk of slowing things down) is to set the - SYNAPSE_CACHE_FACTOR environment variable. Roughly speaking, a - SYNAPSE_CACHE_FACTOR of 1.0 will max out at around 3-4GB of - resident memory - this is what we currently run the matrix.org - on. The default setting is currently 0.1, which is probably around - a ~700MB footprint. You can dial it down further to 0.02 if - desired, which targets roughly ~512MB. Conversely you can dial it - up if you need performance for lots of users and have a box with a - lot of RAM. + Synapse's architecture is quite RAM hungry currently - we deliberately + cache a lot of recent room data and metadata in RAM in order to speed up + common requests. We'll improve this in the future, but for now the easiest + way to either reduce the RAM usage (at the risk of slowing things down) + is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment + variable. The default is 0.5, which can be decreased to reduce RAM usage + in memory constrained enviroments, or increased if performance starts to + degrade. + + However, degraded performance due to a low cache factor, common on + machines with slow disks, often leads to explosions in memory use due + backlogged requests. In this case, reducing the cache factor will make + things worse. Instead, try increasing it drastically. 2.0 is a good + starting value. ## COPYRIGHT diff --git a/docker/Dockerfile b/docker/Dockerfile index 9a3cf7b3f5..27512f8600 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,34 +16,36 @@ ARG PYTHON_VERSION=3.7 ### ### Stage 0: builder ### -FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 as builder +FROM docker.io/python:${PYTHON_VERSION}-slim as builder # install the OS build deps - -RUN apk add \ - build-base \ - libffi-dev \ - libjpeg-turbo-dev \ - libressl-dev \ - libxslt-dev \ - linux-headers \ - postgresql-dev \ - zlib-dev - -# build things which have slow build steps, before we copy synapse, so that -# the layer can be cached. -# -# (we really just care about caching a wheel here, as the "pip install" below -# will install them again.) - +RUN apt-get update && apt-get install -y \ + build-essential \ + libffi-dev \ + libjpeg-dev \ + libpq-dev \ + libssl-dev \ + libwebp-dev \ + libxml++2.6-dev \ + libxslt1-dev \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Build dependencies that are not available as wheels, to speed up rebuilds RUN pip install --prefix="/install" --no-warn-script-location \ - cryptography \ - msgpack-python \ - pillow \ - pynacl + frozendict \ + jaeger-client \ + opentracing \ + prometheus-client \ + psycopg2 \ + pycparser \ + pyrsistent \ + pyyaml \ + simplejson \ + threadloop \ + thrift # now install synapse and all of the python deps to /install. - COPY synapse /synapse/synapse/ COPY scripts /synapse/scripts/ COPY MANIFEST.in README.rst setup.py synctl /synapse/ @@ -55,19 +57,16 @@ RUN pip install --prefix="/install" --no-warn-script-location \ ### Stage 1: runtime ### -FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 +FROM docker.io/python:${PYTHON_VERSION}-slim -# xmlsec is required for saml support -RUN apk add --no-cache --virtual .runtime_deps \ - libffi \ - libjpeg-turbo \ - libressl \ - libxslt \ - libpq \ - zlib \ - su-exec \ - tzdata \ - xmlsec +RUN apt-get update && apt-get install -y \ + curl \ + gosu \ + libjpeg62-turbo \ + libpq5 \ + libwebp6 \ + xmlsec1 \ + && rm -rf /var/lib/apt/lists/* COPY --from=builder /install /usr/local COPY ./docker/start.py /start.py @@ -78,3 +77,6 @@ VOLUME ["/data"] EXPOSE 8008/tcp 8009/tcp 8448/tcp ENTRYPOINT ["/start.py"] + +HEALTHCHECK --interval=1m --timeout=5s \ + CMD curl -fSs http://localhost:8008/health || exit 1 diff --git a/docker/README.md b/docker/README.md index 8c337149ca..d0da34778e 100644 --- a/docker/README.md +++ b/docker/README.md @@ -94,6 +94,21 @@ The following environment variables are supported in run mode: * `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`. * `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`. +## Generating an (admin) user + +After synapse is running, you may wish to create a user via `register_new_matrix_user`. + +This requires a `registration_shared_secret` to be set in your config file. Synapse +must be restarted to pick up this change. + +You can then call the script: + +``` +docker exec -it synapse register_new_matrix_user http://localhost:8008 -c /data/homeserver.yaml --help +``` + +Remember to remove the `registration_shared_secret` and restart if you no-longer need it. + ## TLS support The default configuration exposes a single HTTP port: http://localhost:8008. It @@ -147,3 +162,32 @@ docker build -t matrixdotorg/synapse -f docker/Dockerfile . You can choose to build a different docker image by changing the value of the `-f` flag to point to another Dockerfile. + +## Disabling the healthcheck + +If you are using a non-standard port or tls inside docker you can disable the healthcheck +whilst running the above `docker run` commands. + +``` + --no-healthcheck +``` +## Setting custom healthcheck on docker run + +If you wish to point the healthcheck at a different port with docker command, add the following + +``` + --health-cmd 'curl -fSs http://localhost:1234/health' +``` + +## Setting the healthcheck in docker-compose file + +You can add the following to set a custom healthcheck in a docker compose file. +You will need version >2.1 for this to work. + +``` +healthcheck: + test: ["CMD", "curl", "-fSs", "http://localhost:8008/health"] + interval: 1m + timeout: 10s + retries: 3 +``` diff --git a/docker/conf/log.config b/docker/conf/log.config index ed418a57cd..491bbcc87a 100644 --- a/docker/conf/log.config +++ b/docker/conf/log.config @@ -4,16 +4,10 @@ formatters: precise: format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' -filters: - context: - (): synapse.logging.context.LoggingContextFilter - request: "" - handlers: console: class: logging.StreamHandler formatter: precise - filters: [context] loggers: synapse.storage.SQL: diff --git a/docker/start.py b/docker/start.py index 2a25c9380e..9f08134158 100755 --- a/docker/start.py +++ b/docker/start.py @@ -120,7 +120,7 @@ def generate_config_from_template(config_dir, config_path, environ, ownership): if ownership is not None: subprocess.check_output(["chown", "-R", ownership, "/data"]) - args = ["su-exec", ownership] + args + args = ["gosu", ownership] + args subprocess.check_output(args) @@ -172,8 +172,8 @@ def run_generate_config(environ, ownership): # make sure that synapse has perms to write to the data dir. subprocess.check_output(["chown", ownership, data_dir]) - args = ["su-exec", ownership] + args - os.execv("/sbin/su-exec", args) + args = ["gosu", ownership] + args + os.execv("/usr/sbin/gosu", args) else: os.execv("/usr/local/bin/python", args) @@ -189,7 +189,7 @@ def main(args, environ): ownership = "{}:{}".format(desired_uid, desired_gid) if ownership is None: - log("Will not perform chmod/su-exec as UserID already matches request") + log("Will not perform chmod/gosu as UserID already matches request") # In generate mode, generate a configuration and missing keys, then exit if mode == "generate": @@ -236,8 +236,8 @@ running with 'migrate_config'. See the README for more details. args = ["python", "-m", synapse_worker, "--config-path", config_path] if ownership is not None: - args = ["su-exec", ownership] + args - os.execv("/sbin/su-exec", args) + args = ["gosu", ownership] + args + os.execv("/usr/sbin/gosu", args) else: os.execv("/usr/local/bin/python", args) diff --git a/docs/.sample_config_header.yaml b/docs/.sample_config_header.yaml index 35a591d042..8c9b31acdb 100644 --- a/docs/.sample_config_header.yaml +++ b/docs/.sample_config_header.yaml @@ -10,5 +10,16 @@ # homeserver.yaml. Instead, if you are starting from scratch, please generate # a fresh config using Synapse by following the instructions in INSTALL.md. +# Configuration options that take a time period can be set using a number +# followed by a letter. Letters have the following meanings: +# s = second +# m = minute +# h = hour +# d = day +# w = week +# y = year +# For example, setting redaction_retention_period: 5m would remove redacted +# messages from the database after 5 minutes, rather than 5 months. + ################################################################################ diff --git a/docs/ACME.md b/docs/ACME.md index f4c4740476..a7a498f575 100644 --- a/docs/ACME.md +++ b/docs/ACME.md @@ -12,13 +12,14 @@ introduced support for automatically provisioning certificates through In [March 2019](https://community.letsencrypt.org/t/end-of-life-plan-for-acmev1/88430), Let's Encrypt announced that they were deprecating version 1 of the ACME protocol, with the plan to disable the use of it for new accounts in -November 2019, and for existing accounts in June 2020. +November 2019, for new domains in June 2020, and for existing accounts and +domains in June 2021. Synapse doesn't currently support version 2 of the ACME protocol, which means that: * for existing installs, Synapse's built-in ACME support will continue - to work until June 2020. + to work until June 2021. * for new installs, this feature will not work at all. Either way, it is recommended to move from Synapse's ACME support diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md index 64ea7b6a64..ae01a543c6 100644 --- a/docs/admin_api/purge_room.md +++ b/docs/admin_api/purge_room.md @@ -5,6 +5,8 @@ This API will remove all trace of a room from your database. All local users must have left the room before it can be removed. +See also: [Delete Room API](rooms.md#delete-room-api) + The API is: ``` diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 624e7745ba..0f267d2b7b 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -318,3 +318,134 @@ Response: "state_events": 93534 } ``` + +# Room Members API + +The Room Members admin API allows server admins to get a list of all members of a room. + +The response includes the following fields: + +* `members` - A list of all the members that are present in the room, represented by their ids. +* `total` - Total number of members in the room. + +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms/<room_id>/members + +{} +``` + +Response: + +``` +{ + "members": [ + "@foo:matrix.org", + "@bar:matrix.org", + "@foobar:matrix.org + ], + "total": 3 +} +``` + +# Delete Room API + +The Delete Room admin API allows server admins to remove rooms from server +and block these rooms. +It is a combination and improvement of "[Shutdown room](shutdown_room.md)" +and "[Purge room](purge_room.md)" API. + +Shuts down a room. Moves all local users and room aliases automatically to a +new room if `new_room_user_id` is set. Otherwise local users only +leave the room without any information. + +The new room will be created with the user specified by the `new_room_user_id` parameter +as room administrator and will contain a message explaining what happened. Users invited +to the new room will have power level `-10` by default, and thus be unable to speak. + +If `block` is `True` it prevents new joins to the old room. + +This API will remove all trace of the old room from your database after removing +all local users. If `purge` is `true` (the default), all traces of the old room will +be removed from your database after removing all local users. If you do not want +this to happen, set `purge` to `false`. +Depending on the amount of history being purged a call to the API may take +several minutes or longer. + +The local server will only have the power to move local user and room aliases to +the new room. Users on other servers will be unaffected. + +The API is: + +```json +POST /_synapse/admin/v1/rooms/<room_id>/delete +``` + +with a body of: +```json +{ + "new_room_user_id": "@someuser:example.com", + "room_name": "Content Violation Notification", + "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.", + "block": true, + "purge": true +} +``` + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see [README.rst](README.rst). + +A response body like the following is returned: + +```json +{ + "kicked_users": [ + "@foobar:example.com" + ], + "failed_to_kick_users": [], + "local_aliases": [ + "#badroom:example.com", + "#evilsaloon:example.com" + ], + "new_room_id": "!newroomid:example.com" +} +``` + +## Parameters + +The following parameters should be set in the URL: + +* `room_id` - The ID of the room. + +The following JSON body parameters are available: + +* `new_room_user_id` - Optional. If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be moved into that + room. If not set, no new room will be created and the users will just be removed + from the old room. The user ID must be on the local server, but does not necessarily + have to belong to a registered user. +* `room_name` - Optional. A string representing the name of the room that new users will be + invited to. Defaults to `Content Violation Notification` +* `message` - Optional. A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly convey why the + original room was shut down. Defaults to `Sharing illegal content on this server + is not permitted and rooms in violation will be blocked.` +* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing + future attempts to join the room. Defaults to `false`. +* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. + Defaults to `true`. + +The JSON body must not be empty. The body must be at least `{}`. + +## Response + +The following fields are returned in the JSON response body: + +* `kicked_users` - An array of users (`user_id`) that were kicked. +* `failed_to_kick_users` - An array of users (`user_id`) that that were not kicked. +* `local_aliases` - An array of strings representing the local aliases that were migrated from + the old room to the new. +* `new_room_id` - A string representing the room ID of the new room. diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md index 54ce1cd234..9b1cb1c184 100644 --- a/docs/admin_api/shutdown_room.md +++ b/docs/admin_api/shutdown_room.md @@ -10,6 +10,8 @@ disallow any further invites or joins. The local server will only have the power to move local user and room aliases to the new room. Users on other servers will be unaffected. +See also: [Delete Room API](rooms.md#delete-room-api) + ## API You will need to authenticate with an access token for an admin user. @@ -31,7 +33,7 @@ You will need to authenticate with an access token for an admin user. * `message` - Optional. A string containing the first message that will be sent as `new_room_user_id` in the new room. Ideally this will clearly convey why the original room was shut down. - + If not specified, the default value of `room_name` is "Content Violation Notification". The default value of `message` is "Sharing illegal content on othis server is not permitted and rooms in violation will be blocked." @@ -70,3 +72,30 @@ Response: "new_room_id": "!newroomid:example.com", }, ``` + +## Undoing room shutdowns + +*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level, +the structure can and does change without notice. + +First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it +never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible +to recover at all: + +* If the room was invite-only, your users will need to be re-invited. +* If the room no longer has any members at all, it'll be impossible to rejoin. +* The first user to rejoin will have to do so via an alias on a different server. + +With all that being said, if you still want to try and recover the room: + +1. For safety reasons, shut down Synapse. +2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';` + * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`. + * The room ID is the same one supplied to the shutdown room API, not the Content Violation room. +3. Restart Synapse. + +You will have to manually handle, if you so choose, the following: + +* Aliases that would have been redirected to the Content Violation room. +* Users that would have been booted from the room (and will have been force-joined to the Content Violation room). +* Removal of the Content Violation room if desired. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 7b030a6285..e21c78a9c6 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -91,10 +91,14 @@ Body parameters: - ``admin``, optional, defaults to ``false``. -- ``deactivated``, optional, defaults to ``false``. +- ``deactivated``, optional. If unspecified, deactivation state will be left + unchanged on existing accounts and set to ``false`` for new accounts. If the user already exists then optional parameters default to the current value. +In order to re-activate an account ``deactivated`` must be set to ``false``. If +users do not login via single-sign-on, a new ``password`` must be provided. + List Accounts ============= @@ -104,7 +108,7 @@ The api is:: GET /_synapse/admin/v2/users?from=0&limit=10&guests=false -To use it, you will need to authenticate by providing an `access_token` for a +To use it, you will need to authenticate by providing an ``access_token`` for a server admin: see `README.rst <README.rst>`_. The parameter ``from`` is optional but used for pagination, denoting the @@ -115,8 +119,11 @@ from a previous call. The parameter ``limit`` is optional but is used for pagination, denoting the maximum number of items to return in this call. Defaults to ``100``. -The parameter ``user_id`` is optional and filters to only users with user IDs -that contain this value. +The parameter ``user_id`` is optional and filters to only return users with user IDs +that contain this value. This parameter is ignored when using the ``name`` parameter. + +The parameter ``name`` is optional and filters to only return users with user ID localparts +**or** displaynames that contain this value. The parameter ``guests`` is optional and if ``false`` will **exclude** guest users. Defaults to ``true`` to include guest users. @@ -207,9 +214,11 @@ Deactivate Account This API deactivates an account. It removes active access tokens, resets the password, and deletes third-party IDs (to prevent the user requesting a -password reset). It can also mark the user as GDPR-erased (stopping their data -from distributed further, and deleting it entirely if there are no other -references to it). +password reset). + +It can also mark the user as GDPR-erased. This means messages sent by the +user will still be visible by anyone that was in the room when these messages +were sent, but hidden from users joining the room afterwards. The api is:: diff --git a/docs/federate.md b/docs/federate.md index a0786b9cf7..b15cd724d1 100644 --- a/docs/federate.md +++ b/docs/federate.md @@ -47,6 +47,18 @@ you invite them to. This can be caused by an incorrectly-configured reverse proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly configure a reverse proxy. +### Known issues + +**HTTP `308 Permanent Redirect` redirects are not followed**: Due to missing features +in the HTTP library used by Synapse, 308 redirects are currently not followed by +federating servers, which can cause `M_UNKNOWN` or `401 Unauthorized` errors. This +may affect users who are redirecting apex-to-www (e.g. `example.com` -> `www.example.com`), +and especially users of the Kubernetes *Nginx Ingress* module, which uses 308 redirect +codes by default. For those Kubernetes users, [this Stackoverflow post](https://stackoverflow.com/a/52617528/5096871) +might be helpful. For other users, switching to a `301 Moved Permanently` code may be +an option. 308 redirect codes will be supported properly in a future +release of Synapse. + ## Running a demo federation of Synapses If you want to get up and running quickly with a trio of homeservers in a diff --git a/docs/jwt.md b/docs/jwt.md new file mode 100644 index 0000000000..5be9fd26e3 --- /dev/null +++ b/docs/jwt.md @@ -0,0 +1,97 @@ +# JWT Login Type + +Synapse comes with a non-standard login type to support +[JSON Web Tokens](https://en.wikipedia.org/wiki/JSON_Web_Token). In general the +documentation for +[the login endpoint](https://matrix.org/docs/spec/client_server/r0.6.1#login) +is still valid (and the mechanism works similarly to the +[token based login](https://matrix.org/docs/spec/client_server/r0.6.1#token-based)). + +To log in using a JSON Web Token, clients should submit a `/login` request as +follows: + +```json +{ + "type": "org.matrix.login.jwt", + "token": "<jwt>" +} +``` + +Note that the login type of `m.login.jwt` is supported, but is deprecated. This +will be removed in a future version of Synapse. + +The `token` field should include the JSON web token with the following claims: + +* The `sub` (subject) claim is required and should encode the local part of the + user ID. +* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`) + claims are optional, but validated if present. +* The issuer (`iss`) claim is optional, but required and validated if configured. +* The audience (`aud`) claim is optional, but required and validated if configured. + Providing the audience claim when not configured will cause validation to fail. + +In the case that the token is not valid, the homeserver must respond with +`403 Forbidden` and an error code of `M_FORBIDDEN`. + +As with other login types, there are additional fields (e.g. `device_id` and +`initial_device_display_name`) which can be included in the above request. + +## Preparing Synapse + +The JSON Web Token integration in Synapse uses the +[`PyJWT`](https://pypi.org/project/pyjwt/) library, which must be installed +as follows: + + * The relevant libraries are included in the Docker images and Debian packages + provided by `matrix.org` so no further action is needed. + + * If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip + install synapse[pyjwt]` to install the necessary dependencies. + + * For other installation mechanisms, see the documentation provided by the + maintainer. + +To enable the JSON web token integration, you should then add an `jwt_config` section +to your configuration file (or uncomment the `enabled: true` line in the +existing section). See [sample_config.yaml](./sample_config.yaml) for some +sample settings. + +## How to test JWT as a developer + +Although JSON Web Tokens are typically generated from an external server, the +examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly. + +1. Configure Synapse with JWT logins, note that this example uses a pre-shared + secret and an algorithm of HS256: + + ```yaml + jwt_config: + enabled: true + secret: "my-secret-token" + algorithm: "HS256" + ``` +2. Generate a JSON web token: + + ```bash + $ pyjwt --key=my-secret-token --alg=HS256 encode sub=test-user + eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc + ``` +3. Query for the login types and ensure `org.matrix.login.jwt` is there: + + ```bash + curl http://localhost:8080/_matrix/client/r0/login + ``` +4. Login used the generated JSON web token from above: + + ```bash + $ curl http://localhost:8082/_matrix/client/r0/login -X POST \ + --data '{"type":"org.matrix.login.jwt","token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc"}' + { + "access_token": "<access token>", + "device_id": "ACBDEFGHI", + "home_server": "localhost:8080", + "user_id": "@test-user:localhost:8480" + } + ``` + +You should now be able to use the returned access token to query the client API. diff --git a/docs/metrics-howto.md b/docs/metrics-howto.md index cf69938a2a..b386ec91c1 100644 --- a/docs/metrics-howto.md +++ b/docs/metrics-howto.md @@ -27,7 +27,7 @@ different thread to Synapse. This can make it more resilient to heavy load meaning metrics cannot be retrieved, and can be exposed to just internal networks easier. The served metrics are available - over HTTP only, and will be available at `/`. + over HTTP only, and will be available at `/_synapse/metrics`. Add a new listener to homeserver.yaml: diff --git a/docs/openid.md b/docs/openid.md index 688379ddd9..70b37f858b 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -23,6 +23,7 @@ such as [Github][github-idp]. [auth0]: https://auth0.com/ [okta]: https://www.okta.com/ [dex-idp]: https://github.com/dexidp/dex +[keycloak-idp]: https://www.keycloak.org/docs/latest/server_admin/#sso-protocols [hydra]: https://www.ory.sh/docs/hydra/ [github-idp]: https://developer.github.com/apps/building-oauth-apps/authorizing-oauth-apps @@ -89,7 +90,50 @@ oidc_config: localpart_template: "{{ user.name }}" display_name_template: "{{ user.name|capitalize }}" ``` +### [Keycloak][keycloak-idp] +[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat. + +Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm. + +1. Click `Clients` in the sidebar and click `Create` + +2. Fill in the fields as below: + +| Field | Value | +|-----------|-----------| +| Client ID | `synapse` | +| Client Protocol | `openid-connect` | + +3. Click `Save` +4. Fill in the fields as below: + +| Field | Value | +|-----------|-----------| +| Client ID | `synapse` | +| Enabled | `On` | +| Client Protocol | `openid-connect` | +| Access Type | `confidential` | +| Valid Redirect URIs | `[synapse public baseurl]/_synapse/oidc/callback` | + +5. Click `Save` +6. On the Credentials tab, update the fields: + +| Field | Value | +|-------|-------| +| Client Authenticator | `Client ID and Secret` | + +7. Click `Regenerate Secret` +8. Copy Secret + +```yaml +oidc_config: + enabled: true + issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" + client_id: "synapse" + client_secret: "copy secret generated from above" + scopes: ["openid", "profile"] +``` ### [Auth0][auth0] 1. Create a regular web application for Synapse diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index 5d9ae67041..7d98d9f255 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -14,107 +14,109 @@ password auth provider module implementations: * [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/) * [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth) +* [matrix-synapse-rest-password-provider](https://github.com/ma1uta/matrix-synapse-rest-password-provider) ## Required methods Password auth provider classes must provide the following methods: -*class* `SomeProvider.parse_config`(*config*) +* `parse_config(config)` + This method is passed the `config` object for this module from the + homeserver configuration file. -> This method is passed the `config` object for this module from the -> homeserver configuration file. -> -> It should perform any appropriate sanity checks on the provided -> configuration, and return an object which is then passed into -> `__init__`. + It should perform any appropriate sanity checks on the provided + configuration, and return an object which is then passed into -*class* `SomeProvider`(*config*, *account_handler*) + This method should have the `@staticmethod` decoration. -> The constructor is passed the config object returned by -> `parse_config`, and a `synapse.module_api.ModuleApi` object which -> allows the password provider to check if accounts exist and/or create -> new ones. +* `__init__(self, config, account_handler)` + + The constructor is passed the config object returned by + `parse_config`, and a `synapse.module_api.ModuleApi` object which + allows the password provider to check if accounts exist and/or create + new ones. ## Optional methods -Password auth provider classes may optionally provide the following -methods. - -*class* `SomeProvider.get_db_schema_files`() - -> This method, if implemented, should return an Iterable of -> `(name, stream)` pairs of database schema files. Each file is applied -> in turn at initialisation, and a record is then made in the database -> so that it is not re-applied on the next start. - -`someprovider.get_supported_login_types`() - -> This method, if implemented, should return a `dict` mapping from a -> login type identifier (such as `m.login.password`) to an iterable -> giving the fields which must be provided by the user in the submission -> to the `/login` api. These fields are passed in the `login_dict` -> dictionary to `check_auth`. -> -> For example, if a password auth provider wants to implement a custom -> login type of `com.example.custom_login`, where the client is expected -> to pass the fields `secret1` and `secret2`, the provider should -> implement this method and return the following dict: -> -> {"com.example.custom_login": ("secret1", "secret2")} - -`someprovider.check_auth`(*username*, *login_type*, *login_dict*) - -> This method is the one that does the real work. If implemented, it -> will be called for each login attempt where the login type matches one -> of the keys returned by `get_supported_login_types`. -> -> It is passed the (possibly UNqualified) `user` provided by the client, -> the login type, and a dictionary of login secrets passed by the -> client. -> -> The method should return a Twisted `Deferred` object, which resolves -> to the canonical `@localpart:domain` user id if authentication is -> successful, and `None` if not. -> -> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in -> which case the second field is a callback which will be called with -> the result from the `/login` call (including `access_token`, -> `device_id`, etc.) - -`someprovider.check_3pid_auth`(*medium*, *address*, *password*) - -> This method, if implemented, is called when a user attempts to -> register or log in with a third party identifier, such as email. It is -> passed the medium (ex. "email"), an address (ex. -> "<jdoe@example.com>") and the user's password. -> -> The method should return a Twisted `Deferred` object, which resolves -> to a `str` containing the user's (canonical) User ID if -> authentication was successful, and `None` if not. -> -> As with `check_auth`, the `Deferred` may alternatively resolve to a -> `(user_id, callback)` tuple. - -`someprovider.check_password`(*user_id*, *password*) - -> This method provides a simpler interface than -> `get_supported_login_types` and `check_auth` for password auth -> providers that just want to provide a mechanism for validating -> `m.login.password` logins. -> -> Iif implemented, it will be called to check logins with an -> `m.login.password` login type. It is passed a qualified -> `@localpart:domain` user id, and the password provided by the user. -> -> The method should return a Twisted `Deferred` object, which resolves -> to `True` if authentication is successful, and `False` if not. - -`someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*) - -> This method, if implemented, is called when a user logs out. It is -> passed the qualified user ID, the ID of the deactivated device (if -> any: access tokens are occasionally created without an associated -> device ID), and the (now deactivated) access token. -> -> It may return a Twisted `Deferred` object; the logout request will -> wait for the deferred to complete but the result is ignored. +Password auth provider classes may optionally provide the following methods: + +* `get_db_schema_files(self)` + + This method, if implemented, should return an Iterable of + `(name, stream)` pairs of database schema files. Each file is applied + in turn at initialisation, and a record is then made in the database + so that it is not re-applied on the next start. + +* `get_supported_login_types(self)` + + This method, if implemented, should return a `dict` mapping from a + login type identifier (such as `m.login.password`) to an iterable + giving the fields which must be provided by the user in the submission + to [the `/login` API](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login). + These fields are passed in the `login_dict` dictionary to `check_auth`. + + For example, if a password auth provider wants to implement a custom + login type of `com.example.custom_login`, where the client is expected + to pass the fields `secret1` and `secret2`, the provider should + implement this method and return the following dict: + + ```python + {"com.example.custom_login": ("secret1", "secret2")} + ``` + +* `check_auth(self, username, login_type, login_dict)` + + This method does the real work. If implemented, it + will be called for each login attempt where the login type matches one + of the keys returned by `get_supported_login_types`. + + It is passed the (possibly unqualified) `user` field provided by the client, + the login type, and a dictionary of login secrets passed by the + client. + + The method should return an `Awaitable` object, which resolves + to the canonical `@localpart:domain` user ID if authentication is + successful, and `None` if not. + + Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in + which case the second field is a callback which will be called with + the result from the `/login` call (including `access_token`, + `device_id`, etc.) + +* `check_3pid_auth(self, medium, address, password)` + + This method, if implemented, is called when a user attempts to + register or log in with a third party identifier, such as email. It is + passed the medium (ex. "email"), an address (ex. + "<jdoe@example.com>") and the user's password. + + The method should return an `Awaitable` object, which resolves + to a `str` containing the user's (canonical) User id if + authentication was successful, and `None` if not. + + As with `check_auth`, the `Awaitable` may alternatively resolve to a + `(user_id, callback)` tuple. + +* `check_password(self, user_id, password)` + + This method provides a simpler interface than + `get_supported_login_types` and `check_auth` for password auth + providers that just want to provide a mechanism for validating + `m.login.password` logins. + + If implemented, it will be called to check logins with an + `m.login.password` login type. It is passed a qualified + `@localpart:domain` user id, and the password provided by the user. + + The method should return an `Awaitable` object, which resolves + to `True` if authentication is successful, and `False` if not. + +* `on_logged_out(self, user_id, device_id, access_token)` + + This method, if implemented, is called when a user logs out. It is + passed the qualified user ID, the ID of the deactivated device (if + any: access tokens are occasionally created without an associated + device ID), and the (now deactivated) access token. + + It may return an `Awaitable` object; the logout request will + wait for the `Awaitable` to complete, but the result is ignored. diff --git a/docs/postgres.md b/docs/postgres.md index 70fe29cdcc..e71a1975d8 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -188,6 +188,9 @@ to do step 2. It is safe to at any time kill the port script and restart it. +Note that the database may take up significantly more (25% - 100% more) +space on disk after porting to Postgres. + ### Using the port script Firstly, shut down the currently running synapse server and copy its diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index cbb8269568..fd48ba0874 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -3,7 +3,7 @@ It is recommended to put a reverse proxy such as [nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html), [Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html), -[Caddy](https://caddyserver.com/docs/proxy) or +[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or [HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage of doing so is that it means that you can expose the default https port (443) to Matrix clients without needing to run Synapse with root @@ -38,6 +38,11 @@ the reverse proxy and the homeserver. server { listen 443 ssl; listen [::]:443 ssl; + + # For the federation port + listen 8448 ssl default_server; + listen [::]:8448 ssl default_server; + server_name matrix.example.com; location /_matrix { @@ -48,17 +53,6 @@ server { client_max_body_size 10M; } } - -server { - listen 8448 ssl default_server; - listen [::]:8448 ssl default_server; - server_name example.com; - - location / { - proxy_pass http://localhost:8008; - proxy_set_header X-Forwarded-For $remote_addr; - } -} ``` **NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will @@ -145,3 +139,10 @@ client IP addresses are recorded correctly. Having done so, you can then use `https://matrix.example.com` (instead of `https://matrix.example.com:8448`) as the "Custom server" when connecting to Synapse from a client. + + +## Health check endpoint + +Synapse exposes a health check endpoint for use by reverse proxies. +Each configured HTTP listener has a `/health` endpoint which always returns +200 OK (and doesn't get logged). diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a60b15fa4e..e8348ce1ff 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -10,6 +10,17 @@ # homeserver.yaml. Instead, if you are starting from scratch, please generate # a fresh config using Synapse by following the instructions in INSTALL.md. +# Configuration options that take a time period can be set using a number +# followed by a letter. Letters have the following meanings: +# s = second +# m = minute +# h = hour +# d = day +# w = week +# y = year +# For example, setting redaction_retention_period: 5m would remove redacted +# messages from the database after 5 minutes, rather than 5 months. + ################################################################################ # Configuration file for Synapse. @@ -102,7 +113,9 @@ pid_file: DATADIR/homeserver.pid #gc_thresholds: [700, 10, 10] # Set the limit on the returned events in the timeline in the get -# and sync operations. The default value is -1, means no upper limit. +# and sync operations. The default value is 100. -1 means no upper limit. +# +# Uncomment the following to increase the limit to 5000. # #filter_timeline_limit: 5000 @@ -118,38 +131,6 @@ pid_file: DATADIR/homeserver.pid # #enable_search: false -# Restrict federation to the following whitelist of domains. -# N.B. we recommend also firewalling your federation listener to limit -# inbound federation traffic as early as possible, rather than relying -# purely on this application-layer restriction. If not specified, the -# default is to whitelist everything. -# -#federation_domain_whitelist: -# - lon.example.com -# - nyc.example.com -# - syd.example.com - -# Prevent federation requests from being sent to the following -# blacklist IP address CIDR ranges. If this option is not specified, or -# specified with an empty list, no ip range blacklist will be enforced. -# -# As of Synapse v1.4.0 this option also affects any outbound requests to identity -# servers provided by user input. -# -# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly -# listed here, since they correspond to unroutable addresses.) -# -federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -178,7 +159,7 @@ federation_ip_range_blacklist: # names: a list of names of HTTP resources. See below for a list of # valid resource names. # -# compress: set to true to enable HTTP comression for this resource. +# compress: set to true to enable HTTP compression for this resource. # # additional_resources: Only valid for an 'http' listener. A map of # additional endpoints which should be loaded via dynamic modules. @@ -283,7 +264,7 @@ listeners: # number of monthly active users. # # 'limit_usage_by_mau' disables/enables monthly active user blocking. When -# anabled and a limit is reached the server returns a 'ResourceLimitError' +# enabled and a limit is reached the server returns a 'ResourceLimitError' # with error type Codes.RESOURCE_LIMIT_EXCEEDED # # 'max_mau_value' is the hard limit of monthly active users above which @@ -344,6 +325,10 @@ limit_remote_rooms: # #complexity_error: "This room is too complex." + # allow server admins to join complex rooms. Default is false. + # + #admins_can_join: true + # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. # @@ -393,11 +378,10 @@ retention: # min_lifetime: 1d # max_lifetime: 1y - # Retention policy limits. If set, a user won't be able to send a - # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' - # that's not within this range. This is especially useful in closed federations, - # in which server admins can make sure every federating server applies the same - # rules. + # Retention policy limits. If set, and the state of a room contains a + # 'm.room.retention' event in its state which contains a 'min_lifetime' or a + # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy + # to these limits when running purge jobs. # #allowed_lifetime_min: 1d #allowed_lifetime_max: 1y @@ -423,12 +407,19 @@ retention: # (e.g. every 12h), but not want that purge to be performed by a job that's # iterating over every room it knows, which could be heavy on the server. # + # If any purge job is configured, it is strongly recommended to have at least + # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' + # set, or one job without 'shortest_max_lifetime' and one job without + # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if + # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a + # room's policy to these values is done after the policies are retrieved from + # Synapse's database (which is done using the range specified in a purge job's + # configuration). + # #purge_jobs: - # - shortest_max_lifetime: 1d - # longest_max_lifetime: 3d + # - longest_max_lifetime: 3d # interval: 12h # - shortest_max_lifetime: 3d - # longest_max_lifetime: 1y # interval: 1d # Inhibits the /requestToken endpoints from returning an error that might leak @@ -608,6 +599,39 @@ acme: +# Restrict federation to the following whitelist of domains. +# N.B. we recommend also firewalling your federation listener to limit +# inbound federation traffic as early as possible, rather than relying +# purely on this application-layer restriction. If not specified, the +# default is to whitelist everything. +# +#federation_domain_whitelist: +# - lon.example.com +# - nyc.example.com +# - syd.example.com + +# Prevent federation requests from being sent to the following +# blacklist IP address CIDR ranges. If this option is not specified, or +# specified with an empty list, no ip range blacklist will be enforced. +# +# As of Synapse v1.4.0 this option also affects any outbound requests to identity +# servers provided by user input. +# +# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly +# listed here, since they correspond to unroutable addresses.) +# +federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + + ## Caching ## # Caching can be configured through the following options. @@ -682,7 +706,7 @@ caches: #database: # name: psycopg2 # args: -# user: synapse +# user: synapse_user # password: secretpassword # database: synapse # host: localhost @@ -728,6 +752,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # - one for ratelimiting redactions by room admins. If this is not explicitly # set then it uses the same ratelimiting as per rc_message. This is useful # to allow room admins to deal with abuse quickly. +# - two for ratelimiting number of rooms a user can join, "local" for when +# users are joining rooms the server is already in (this is cheap) vs +# "remote" for when users are trying to join rooms not on the server (which +# can be more expensive) # # The defaults are as shown below. # @@ -753,6 +781,14 @@ log_config: "CONFDIR/SERVERNAME.log.config" #rc_admin_redaction: # per_second: 1 # burst_count: 50 +# +#rc_joins: +# local: +# per_second: 0.1 +# burst_count: 3 +# remote: +# per_second: 0.01 +# burst_count: 3 # Ratelimiting settings for incoming federation @@ -1147,24 +1183,6 @@ account_validity: # #default_identity_server: https://matrix.org -# The list of identity servers trusted to verify third party -# identifiers by this server. -# -# Also defines the ID server which will be called when an account is -# deactivated (one will be picked arbitrarily). -# -# Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity -# server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a -# background migration script, informing itself that the identity server all of its -# 3PIDs have been bound to is likely one of the below. -# -# As of Synapse v1.4.0, all other functionality of this option has been deprecated, and -# it is now solely used for the purposes of the background migration script, and can be -# removed once it has run. -#trusted_third_party_id_servers: -# - matrix.org -# - vector.im - # Handle threepid (email/phone etc) registration and password resets through a set of # *trusted* identity servers. Note that this allows the configured identity server to # reset passwords for accounts! @@ -1215,7 +1233,11 @@ account_threepid_delegates: #enable_3pid_changes: false # Users who register on this homeserver will automatically be joined -# to these rooms +# to these rooms. +# +# By default, any room aliases included in this list will be created +# as a publicly joinable room when the first user registers for the +# homeserver. This behaviour can be customised with the settings below. # #auto_join_rooms: # - "#example:example.com" @@ -1223,10 +1245,62 @@ account_threepid_delegates: # Where auto_join_rooms are specified, setting this flag ensures that the # the rooms exist by creating them when the first user on the # homeserver registers. +# +# By default the auto-created rooms are publicly joinable from any federated +# server. Use the autocreate_auto_join_rooms_federated and +# autocreate_auto_join_room_preset settings below to customise this behaviour. +# # Setting to false means that if the rooms are not manually created, # users cannot be auto-joined since they do not exist. # -#autocreate_auto_join_rooms: true +# Defaults to true. Uncomment the following line to disable automatically +# creating auto-join rooms. +# +#autocreate_auto_join_rooms: false + +# Whether the auto_join_rooms that are auto-created are available via +# federation. Only has an effect if autocreate_auto_join_rooms is true. +# +# Note that whether a room is federated cannot be modified after +# creation. +# +# Defaults to true: the room will be joinable from other servers. +# Uncomment the following to prevent users from other homeservers from +# joining these rooms. +# +#autocreate_auto_join_rooms_federated: false + +# The room preset to use when auto-creating one of auto_join_rooms. Only has an +# effect if autocreate_auto_join_rooms is true. +# +# This can be one of "public_chat", "private_chat", or "trusted_private_chat". +# If a value of "private_chat" or "trusted_private_chat" is used then +# auto_join_mxid_localpart must also be configured. +# +# Defaults to "public_chat", meaning that the room is joinable by anyone, including +# federated servers if autocreate_auto_join_rooms_federated is true (the default). +# Uncomment the following to require an invitation to join these rooms. +# +#autocreate_auto_join_room_preset: private_chat + +# The local part of the user id which is used to create auto_join_rooms if +# autocreate_auto_join_rooms is true. If this is not provided then the +# initial user account that registers will be used to create the rooms. +# +# The user id is also used to invite new users to any auto-join rooms which +# are set to invite-only. +# +# It *must* be configured if autocreate_auto_join_room_preset is set to +# "private_chat" or "trusted_private_chat". +# +# Note that this must be specified in order for new users to be correctly +# invited to any auto-join rooms which have been set to invite-only (either +# at the time of creation or subsequently). +# +# Note that, if the room already exists, this user must be joined and +# have the appropriate permissions to invite new members. +# +#auto_join_mxid_localpart: system # When auto_join_rooms is specified, setting this flag to false prevents # guest accounts from being automatically joined to the rooms. @@ -1459,7 +1533,7 @@ saml2_config: # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. - # The default is 5 minutes. + # The default is 15 minutes. # #saml_session_lifetime: 5m @@ -1514,6 +1588,17 @@ saml2_config: # #grandfathered_mxid_source_attribute: upn + # It is possible to configure Synapse to only allow logins if SAML attributes + # match particular values. The requirements can be listed under + # `attribute_requirements` as shown below. All of the listed attributes must + # match for the login to be permitted. + # + #attribute_requirements: + # - attribute: userGroup + # value: "staff" + # - attribute: department + # value: "sales" + # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # @@ -1544,7 +1629,7 @@ saml2_config: # use an OpenID Connect Provider for authentication, instead of its internal # password database. # -# See https://github.com/matrix-org/synapse/blob/master/openid.md. +# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md. # oidc_config: # Uncomment the following to enable authorization against an OpenID Connect @@ -1753,12 +1838,60 @@ sso: #template_dir: "res/templates" -# The JWT needs to contain a globally unique "sub" (subject) claim. +# JSON web token integration. The following settings can be used to make +# Synapse JSON web tokens for authentication, instead of its internal +# password database. +# +# Each JSON Web Token needs to contain a "sub" (subject) claim, which is +# used as the localpart of the mxid. +# +# Additionally, the expiration time ("exp"), not before time ("nbf"), +# and issued at ("iat") claims are validated if present. +# +# Note that this is a non-standard login type and client support is +# expected to be non-existant. +# +# See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md. # #jwt_config: -# enabled: true -# secret: "a secret" -# algorithm: "HS256" + # Uncomment the following to enable authorization using JSON web + # tokens. Defaults to false. + # + #enabled: true + + # This is either the private shared secret or the public key used to + # decode the contents of the JSON web token. + # + # Required if 'enabled' is true. + # + #secret: "provided-by-your-issuer" + + # The algorithm used to sign the JSON web token. + # + # Supported algorithms are listed at + # https://pyjwt.readthedocs.io/en/latest/algorithms.html + # + # Required if 'enabled' is true. + # + #algorithm: "provided-by-your-issuer" + + # The issuer to validate the "iss" claim against. + # + # Optional, if provided the "iss" claim will be required and + # validated for all JSON web tokens. + # + #issuer: "provided-by-your-issuer" + + # A list of audiences to validate the "aud" claim against. + # + # Optional, if provided the "aud" claim will be required and + # validated for all JSON web tokens. + # + # Note that if the "aud" claim is included in a JSON web token then + # validation will fail without configuring audiences. + # + #audiences: + # - "provided-by-your-issuer" password_config: @@ -1849,8 +1982,8 @@ email: # #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" - # app_name defines the default value for '%(app)s' in notif_from. It - # defaults to 'Matrix'. + # app_name defines the default value for '%(app)s' in notif_from and email + # subjects. It defaults to 'Matrix'. # #app_name: my_branded_matrix_server @@ -1880,9 +2013,7 @@ email: # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # @@ -1919,6 +2050,73 @@ email: # #template_dir: "res/templates" + # Subjects to use when sending emails from Synapse. + # + # The placeholder '%(app)s' will be replaced with the value of the 'app_name' + # setting above, or by a value dictated by the Matrix client application. + # + # If a subject isn't overridden in this configuration file, the value used as + # its example will be used. + # + #subjects: + + # Subjects for notification emails. + # + # On top of the '%(app)s' placeholder, these can use the following + # placeholders: + # + # * '%(person)s', which will be replaced by the display name of the user(s) + # that sent the message(s), e.g. "Alice and Bob". + # * '%(room)s', which will be replaced by the name of the room the + # message(s) have been sent to, e.g. "My super room". + # + # See the example provided for each setting to see which placeholder can be + # used and how to use them. + # + # Subject to use to notify about one message from one or more user(s) in a + # room which has a name. + #message_from_person_in_room: "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room..." + # + # Subject to use to notify about one message from one or more user(s) in a + # room which doesn't have a name. + #message_from_person: "[%(app)s] You have a message on %(app)s from %(person)s..." + # + # Subject to use to notify about multiple messages from one or more users in + # a room which doesn't have a name. + #messages_from_person: "[%(app)s] You have messages on %(app)s from %(person)s..." + # + # Subject to use to notify about multiple messages in a room which has a + # name. + #messages_in_room: "[%(app)s] You have messages on %(app)s in the %(room)s room..." + # + # Subject to use to notify about multiple messages in multiple rooms. + #messages_in_room_and_others: "[%(app)s] You have messages on %(app)s in the %(room)s room and others..." + # + # Subject to use to notify about multiple messages from multiple persons in + # multiple rooms. This is similar to the setting above except it's used when + # the room in which the notification was triggered has no name. + #messages_from_person_and_others: "[%(app)s] You have messages on %(app)s from %(person)s and others..." + # + # Subject to use to notify about an invite to a room which has a name. + #invite_from_person_to_room: "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s..." + # + # Subject to use to notify about an invite to a room which doesn't have a + # name. + #invite_from_person: "[%(app)s] %(person)s has invited you to chat on %(app)s..." + + # Subject for emails related to account administration. + # + # On top of the '%(app)s' placeholder, these one can use the + # '%(server_name)s' placeholder, which will be replaced by the value of the + # 'server_name' setting in your Synapse configuration. + # + # Subject to use when sending a password reset email. + #password_reset: "[%(server_name)s] Password reset" + # + # Subject to use when sending a verification email to assert an address's + # ownership. + #email_validation: "[%(server_name)s] Validate your email" + # Password providers allow homeserver administrators to integrate # their Synapse installation with existing authentication methods @@ -1978,6 +2176,26 @@ spam_checker: # example_stop_events_from: ['@bad:example.com'] +## Rooms ## + +# Controls whether locally-created rooms should be end-to-end encrypted by +# default. +# +# Possible options are "all", "invite", and "off". They are defined as: +# +# * "all": any locally-created room +# * "invite": any room created with the "private_chat" or "trusted_private_chat" +# room creation presets +# * "off": this option will take no effect +# +# The default value is "off". +# +# Note that this option will only affect rooms created after it is set. It +# will also not affect rooms created by other servers. +# +#encryption_enabled_by_default_for_room_type: invite + + # Uncomment to allow non-server-admin users to create groups on this server # #enable_group_creation: true @@ -2209,3 +2427,57 @@ opentracing: # # logging: # false + + +## Workers ## + +# Disables sending of outbound federation transactions on the main process. +# Uncomment if using a federation sender worker. +# +#send_federation: false + +# It is possible to run multiple federation sender workers, in which case the +# work is balanced across them. +# +# This configuration must be shared between all federation sender workers, and if +# changed all federation sender workers must be stopped at the same time and then +# started, to ensure that all instances are running with the same config (otherwise +# events may be dropped). +# +#federation_sender_instances: +# - federation_sender1 + +# When using workers this should be a map from `worker_name` to the +# HTTP replication listener of the worker, if configured. +# +#instance_map: +# worker1: +# host: localhost +# port: 8034 + +# Experimental: When using workers you can define which workers should +# handle event persistence and typing notifications. Any worker +# specified here must also be in the `instance_map`. +# +#stream_writers: +# events: worker1 +# typing: worker1 + + +# Configuration for Redis when using workers. This *must* be enabled when +# using workers (unless using old style direct TCP configuration). +# +redis: + # Uncomment the below to enable Redis support. + # + #enabled: true + + # Optional host and port to use to connect to redis. Defaults to + # localhost and 6379 + # + #host: localhost + #port: 6379 + + # Optional password if configured on the Redis instance + # + #password: <secret_password> diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 1a2739455e..55a48a9ed6 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -11,24 +11,33 @@ formatters: precise: format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' -filters: - context: - (): synapse.logging.context.LoggingContextFilter - request: "" - handlers: file: - class: logging.handlers.RotatingFileHandler + class: logging.handlers.TimedRotatingFileHandler formatter: precise filename: /var/log/matrix-synapse/homeserver.log - maxBytes: 104857600 - backupCount: 10 - filters: [context] + when: midnight + backupCount: 3 # Does not include the current log file. encoding: utf8 + + # Default to buffering writes to log file for efficiency. This means that + # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR + # logs will still be flushed immediately. + buffer: + class: logging.handlers.MemoryHandler + target: file + # The capacity is the number of log lines that are buffered before + # being written to disk. Increasing this will lead to better + # performance, at the expensive of it taking longer for log lines to + # be written to disk. + capacity: 10 + flushLevel: 30 # Flush for WARNING logs as well + + # A handler that writes logs to stderr. Unused by default, but can be used + # instead of "buffer" and "file" in the logger handlers. console: class: logging.StreamHandler formatter: precise - filters: [context] loggers: synapse.storage.SQL: @@ -36,8 +45,23 @@ loggers: # information such as access tokens. level: INFO + twisted: + # We send the twisted logging directly to the file handler, + # to work around https://github.com/matrix-org/synapse/issues/3471 + # when using "buffer" logger. Use "console" to log to stderr instead. + handlers: [file] + propagate: false + root: level: INFO - handlers: [file, console] + + # Write logs to the `buffer` handler, which will buffer them together in memory, + # then write them to a file. + # + # Replace "buffer" with "console" to log to stderr instead. (Note that you'll + # also need to update the configuation for the `twisted` logger above, in + # this case.) + # + handlers: [buffer] disable_existing_loggers: false diff --git a/docs/synctl_workers.md b/docs/synctl_workers.md new file mode 100644 index 0000000000..8da4a31852 --- /dev/null +++ b/docs/synctl_workers.md @@ -0,0 +1,32 @@ +### Using synctl with workers + +If you want to use `synctl` to manage your synapse processes, you will need to +create an an additional configuration file for the main synapse process. That +configuration should look like this: + +```yaml +worker_app: synapse.app.homeserver +``` + +Additionally, each worker app must be configured with the name of a "pid file", +to which it will write its process ID when it starts. For example, for a +synchrotron, you might write: + +```yaml +worker_pid_file: /home/matrix/synapse/worker1.pid +``` + +Finally, to actually run your worker-based synapse, you must pass synctl the `-a` +commandline option to tell it to operate on all the worker configurations found +in the given directory, e.g.: + + synctl -a $CONFIG/workers start + +Currently one should always restart all workers when restarting or upgrading +synapse, unless you explicitly know it's safe not to. For instance, restarting +synapse without restarting all the synchrotrons may result in broken typing +notifications. + +To manipulate a specific worker, you pass the -w option to synctl: + + synctl -w $CONFIG/workers/worker1.yaml restart diff --git a/docs/systemd-with-workers/system/matrix-synapse-worker@.service b/docs/systemd-with-workers/system/matrix-synapse-worker@.service index 39bc5e88e8..cb5ac0ac87 100644 --- a/docs/systemd-with-workers/system/matrix-synapse-worker@.service +++ b/docs/systemd-with-workers/system/matrix-synapse-worker@.service @@ -1,9 +1,14 @@ [Unit] Description=Synapse %i AssertPathExists=/etc/matrix-synapse/workers/%i.yaml + # This service should be restarted when the synapse target is restarted. PartOf=matrix-synapse.target +# if this is started at the same time as the main, let the main process start +# first, to initialise the database schema. +After=matrix-synapse.service + [Service] Type=notify NotifyAccess=main diff --git a/docs/systemd-with-workers/workers/federation_reader.yaml b/docs/systemd-with-workers/workers/federation_reader.yaml index 5b65c7040d..13e69e62c9 100644 --- a/docs/systemd-with-workers/workers/federation_reader.yaml +++ b/docs/systemd-with-workers/workers/federation_reader.yaml @@ -1,7 +1,7 @@ worker_app: synapse.app.federation_reader +worker_name: federation_reader1 worker_replication_host: 127.0.0.1 -worker_replication_port: 9092 worker_replication_http_port: 9093 worker_listeners: diff --git a/docs/user_directory.md b/docs/user_directory.md index 37dc71e751..872fc21979 100644 --- a/docs/user_directory.md +++ b/docs/user_directory.md @@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server. The directory info is stored in various tables, which can (typically after DB corruption) get stale or out of sync. If this happens, for now the -solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql) +solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql) and then restart synapse. This should then start a background task to flush the current tables and regenerate the directory. diff --git a/docs/workers.md b/docs/workers.md index 7512eff43a..bfec745897 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -1,10 +1,10 @@ # Scaling synapse via workers -For small instances it recommended to run Synapse in monolith mode (the -default). For larger instances where performance is a concern it can be helpful -to split out functionality into multiple separate python processes. These -processes are called 'workers', and are (eventually) intended to scale -horizontally independently. +For small instances it recommended to run Synapse in the default monolith mode. +For larger instances where performance is a concern it can be helpful to split +out functionality into multiple separate python processes. These processes are +called 'workers', and are (eventually) intended to scale horizontally +independently. Synapse's worker support is under active development and subject to change as we attempt to rapidly scale ever larger Synapse instances. However we are @@ -16,69 +16,123 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only be used for demo purposes and any admin considering workers should already be running PostgreSQL. -## Master/worker communication +## Main process/worker communication + +The processes communicate with each other via a Synapse-specific protocol called +'replication' (analogous to MySQL- or Postgres-style database replication) which +feeds streams of newly written data between processes so they can be kept in +sync with the database state. + +When configured to do so, Synapse uses a +[Redis pub/sub channel](https://redis.io/topics/pubsub) to send the replication +stream between all configured Synapse processes. Additionally, processes may +make HTTP requests to each other, primarily for operations which need to wait +for a reply ─ such as sending an event. + +Redis support was added in v1.13.0 with it becoming the recommended method in +v1.18.0. It replaced the old direct TCP connections (which is deprecated as of +v1.18.0) to the main process. With Redis, rather than all the workers connecting +to the main process, all the workers and the main process connect to Redis, +which relays replication commands between processes. This can give a significant +cpu saving on the main process and will be a prerequisite for upcoming +performance improvements. + +See the [Architectural diagram](#architectural-diagram) section at the end for +a visualisation of what this looks like. + -The workers communicate with the master process via a Synapse-specific protocol -called 'replication' (analogous to MySQL- or Postgres-style database -replication) which feeds a stream of relevant data from the master to the -workers so they can be kept in sync with the master process and database state. +## Setting up workers -Additionally, workers may make HTTP requests to the master, to send information -in the other direction. Typically this is used for operations which need to -wait for a reply - such as sending an event. +A Redis server is required to manage the communication between the processes. +The Redis server should be installed following the normal procedure for your +distribution (e.g. `apt install redis-server` on Debian). It is safe to use an +existing Redis deployment if you have one. -## Configuration +Once installed, check that Redis is running and accessible from the host running +Synapse, for example by executing `echo PING | nc -q1 localhost 6379` and seeing +a response of `+PONG`. + +The appropriate dependencies must also be installed for Synapse. If using a +virtualenv, these can be installed with: + +```sh +pip install matrix-synapse[redis] +``` + +Note that these dependencies are included when synapse is installed with `pip +install matrix-synapse[all]`. They are also included in the debian packages from +`matrix.org` and in the docker images at +https://hub.docker.com/r/matrixdotorg/synapse/. To make effective use of the workers, you will need to configure an HTTP reverse-proxy such as nginx or haproxy, which will direct incoming requests to -the correct worker, or to the main synapse instance. Note that this includes -requests made to the federation port. See [reverse_proxy.md](reverse_proxy.md) -for information on setting up a reverse proxy. +the correct worker, or to the main synapse instance. See +[reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse +proxy. + +When using workers, each worker process has its own configuration file which +contains settings specific to that worker, such as the HTTP listener that it +provides (if any), logging configuration, etc. + +Normally, the worker processes are configured to read from a shared +configuration file as well as the worker-specific configuration files. This +makes it easier to keep common configuration settings synchronised across all +the processes. + +The main process is somewhat special in this respect: it does not normally +need its own configuration file and can take all of its configuration from the +shared configuration file. + + +### Shared configuration + +Normally, only a couple of changes are needed to make an existing configuration +file suitable for use with workers. First, you need to enable an "HTTP replication +listener" for the main process; and secondly, you need to enable redis-based +replication. For example: -To enable workers, you need to add *two* replication listeners to the -main Synapse configuration file (`homeserver.yaml`). For example: ```yaml +# extend the existing `listeners` section. This defines the ports that the +# main process will listen on. listeners: - # The TCP replication port - - port: 9092 - bind_address: '127.0.0.1' - type: replication - # The HTTP replication port - port: 9093 bind_address: '127.0.0.1' type: http resources: - names: [replication] + +redis: + enabled: true ``` -Under **no circumstances** should these replication API listeners be exposed to -the public internet; they have no authentication and are unencrypted. +See the sample config for the full documentation of each option. + +Under **no circumstances** should the replication listener be exposed to the +public internet; it has no authentication and is unencrypted. -You should then create a set of configs for the various worker processes. Each -worker configuration file inherits the configuration of the main homeserver -configuration file. You can then override configuration specific to that -worker, e.g. the HTTP listener that it provides (if any); logging -configuration; etc. You should minimise the number of overrides though to -maintain a usable config. + +### Worker configuration In the config file for each worker, you must specify the type of worker -application (`worker_app`). The currently available worker applications are -listed below. You must also specify the replication endpoints that it should -talk to on the main synapse process. `worker_replication_host` should specify -the host of the main synapse, `worker_replication_port` should point to the TCP -replication listener port and `worker_replication_http_port` should point to -the HTTP replication port. +application (`worker_app`), and you should specify a unqiue name for the worker +(`worker_name`). The currently available worker applications are listed below. +You must also specify the HTTP replication endpoint that it should talk to on +the main synapse process. `worker_replication_host` should specify the host of +the main synapse and `worker_replication_http_port` should point to the HTTP +replication port. If the worker will handle HTTP requests then the +`worker_listeners` option should be set with a `http` listener, in the same way +as the `listeners` option in the shared config. For example: ```yaml -worker_app: synapse.app.synchrotron +worker_app: synapse.app.generic_worker +worker_name: worker1 -# The replication listener on the synapse to talk to. +# The replication listener on the main synapse process. worker_replication_host: 127.0.0.1 -worker_replication_port: 9092 worker_replication_http_port: 9093 worker_listeners: @@ -87,142 +141,43 @@ worker_listeners: resources: - names: - client + - federation -worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml +worker_log_config: /home/matrix/synapse/config/worker1_log_config.yaml ``` -...is a full configuration for a synchrotron worker instance, which will expose a -plain HTTP `/sync` endpoint on port 8083 separately from the `/sync` endpoint provided -by the main synapse. +...is a full configuration for a generic worker instance, which will expose a +plain HTTP endpoint on port 8083 separately serving various endpoints, e.g. +`/sync`, which are listed below. Obviously you should configure your reverse-proxy to route the relevant endpoints to the worker (`localhost:8083` in the above example). + +### Running Synapse with workers + Finally, you need to start your worker processes. This can be done with either `synctl` or your distribution's preferred service manager such as `systemd`. We recommend the use of `systemd` where available: for information on setting up `systemd` to start synapse workers, see -[systemd-with-workers](systemd-with-workers). To use `synctl`, see below. +[systemd-with-workers](systemd-with-workers). To use `synctl`, see +[synctl_workers.md](synctl_workers.md). -### **Experimental** support for replication over redis - -As of Synapse v1.13.0, it is possible to configure Synapse to send replication -via a [Redis pub/sub channel](https://redis.io/topics/pubsub). This is an -alternative to direct TCP connections to the master: rather than all the -workers connecting to the master, all the workers and the master connect to -Redis, which relays replication commands between processes. This can give a -significant cpu saving on the master and will be a prerequisite for upcoming -performance improvements. - -Note that this support is currently experimental; you may experience lost -messages and similar problems! It is strongly recommended that admins setting -up workers for the first time use direct TCP replication as above. - -To configure Synapse to use Redis: - -1. Install Redis following the normal procedure for your distribution - for - example, on Debian, `apt install redis-server`. (It is safe to use an - existing Redis deployment if you have one: we use a pub/sub stream named - according to the `server_name` of your synapse server.) -2. Check Redis is running and accessible: you should be able to `echo PING | nc -q1 - localhost 6379` and get a response of `+PONG`. -3. Install the python prerequisites. If you installed synapse into a - virtualenv, this can be done with: - ```sh - pip install matrix-synapse[redis] - ``` - The debian packages from matrix.org already include the required - dependencies. -4. Add config to the shared configuration (`homeserver.yaml`): - ```yaml - redis: - enabled: true - ``` - Optional parameters which can go alongside `enabled` are `host`, `port`, - `password`. Normally none of these are required. -5. Restart master and all workers. - -Once redis replication is in use, `worker_replication_port` is redundant and -can be removed from the worker configuration files. Similarly, the -configuration for the `listener` for the TCP replication port can be removed -from the main configuration file. Note that the HTTP replication port is -still required. - -### Using synctl - -If you want to use `synctl` to manage your synapse processes, you will need to -create an an additional configuration file for the master synapse process. That -configuration should look like this: - -```yaml -worker_app: synapse.app.homeserver -``` - -Additionally, each worker app must be configured with the name of a "pid file", -to which it will write its process ID when it starts. For example, for a -synchrotron, you might write: - -```yaml -worker_pid_file: /home/matrix/synapse/synchrotron.pid -``` - -Finally, to actually run your worker-based synapse, you must pass synctl the `-a` -commandline option to tell it to operate on all the worker configurations found -in the given directory, e.g.: - - synctl -a $CONFIG/workers start - -Currently one should always restart all workers when restarting or upgrading -synapse, unless you explicitly know it's safe not to. For instance, restarting -synapse without restarting all the synchrotrons may result in broken typing -notifications. - -To manipulate a specific worker, you pass the -w option to synctl: - - synctl -w $CONFIG/workers/synchrotron.yaml restart ## Available worker applications -### `synapse.app.pusher` - -Handles sending push notifications to sygnal and email. Doesn't handle any -REST endpoints itself, but you should set `start_pushers: False` in the -shared configuration file to stop the main synapse sending these notifications. - -Note this worker cannot be load-balanced: only one instance should be active. - -### `synapse.app.synchrotron` +### `synapse.app.generic_worker` -The synchrotron handles `sync` requests from clients. In particular, it can -handle REST endpoints matching the following regular expressions: +This worker can handle API requests matching the following regular +expressions: + # Sync requests ^/_matrix/client/(v2_alpha|r0)/sync$ ^/_matrix/client/(api/v1|v2_alpha|r0)/events$ ^/_matrix/client/(api/v1|r0)/initialSync$ ^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$ -The above endpoints should all be routed to the synchrotron worker by the -reverse-proxy configuration. - -It is possible to run multiple instances of the synchrotron to scale -horizontally. In this case the reverse-proxy should be configured to -load-balance across the instances, though it will be more efficient if all -requests from a particular user are routed to a single instance. Extracting -a userid from the access token is currently left as an exercise for the reader. - -### `synapse.app.appservice` - -Handles sending output traffic to Application Services. Doesn't handle any -REST endpoints itself, but you should set `notify_appservices: False` in the -shared configuration file to stop the main synapse sending these notifications. - -Note this worker cannot be load-balanced: only one instance should be active. - -### `synapse.app.federation_reader` - -Handles a subset of federation endpoints. In particular, it can handle REST -endpoints matching the following regular expressions: - + # Federation requests ^/_matrix/federation/v1/event/ ^/_matrix/federation/v1/state/ ^/_matrix/federation/v1/state_ids/ @@ -242,40 +197,145 @@ endpoints matching the following regular expressions: ^/_matrix/federation/v1/event_auth/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ - ^/_matrix/federation/v1/send/ ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query + # Inbound federation transaction request + ^/_matrix/federation/v1/send/ + + # Client API requests + ^/_matrix/client/(api/v1|r0|unstable)/publicRooms$ + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/joined_members$ + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/context/.*$ + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$ + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$ + ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$ + ^/_matrix/client/(api/v1|r0|unstable)/keys/query$ + ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$ + ^/_matrix/client/versions$ + ^/_matrix/client/(api/v1|r0|unstable)/voip/turnServer$ + ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$ + ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$ + ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/ + + # Registration/login requests + ^/_matrix/client/(api/v1|r0|unstable)/login$ + ^/_matrix/client/(r0|unstable)/register$ + ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$ + + # Event sending requests + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/ + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ + ^/_matrix/client/(api/v1|r0|unstable)/join/ + ^/_matrix/client/(api/v1|r0|unstable)/profile/ + + Additionally, the following REST endpoints can be handled for GET requests: ^/_matrix/federation/v1/groups/ -The above endpoints should all be routed to the federation_reader worker by the -reverse-proxy configuration. +Pagination requests can also be handled, but all requests for a given +room must be routed to the same instance. Additionally, care must be taken to +ensure that the purge history admin API is not used while pagination requests +for the room are in flight: + + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$ + +Note that a HTTP listener with `client` and `federation` resources must be +configured in the `worker_listeners` option in the worker config. + + +#### Load balancing + +It is possible to run multiple instances of this worker app, with incoming requests +being load-balanced between them by the reverse-proxy. However, different endpoints +have different characteristics and so admins +may wish to run multiple groups of workers handling different endpoints so that +load balancing can be done in different ways. + +For `/sync` and `/initialSync` requests it will be more efficient if all +requests from a particular user are routed to a single instance. Extracting a +user ID from the access token or `Authorization` header is currently left as an +exercise for the reader. Admins may additionally wish to separate out `/sync` +requests that have a `since` query parameter from those that don't (and +`/initialSync`), as requests that don't are known as "initial sync" that happens +when a user logs in on a new device and can be *very* resource intensive, so +isolating these requests will stop them from interfering with other users ongoing +syncs. + +Federation and client requests can be balanced via simple round robin. + +The inbound federation transaction request `^/_matrix/federation/v1/send/` +should be balanced by source IP so that transactions from the same remote server +go to the same process. + +Registration/login requests can be handled separately purely to help ensure that +unexpected load doesn't affect new logins and sign ups. + +Finally, event sending requests can be balanced by the room ID in the URI (or +the full URI, or even just round robin), the room ID is the path component after +`/rooms/`. If there is a large bridge connected that is sending or may send lots +of events, then a dedicated set of workers can be provisioned to limit the +effects of bursts of events from that bridge on events sent by normal users. -The `^/_matrix/federation/v1/send/` endpoint must only be handled by a single -instance. +#### Stream writers -Note that `federation` must be added to the listener resources in the worker config: +Additionally, there is *experimental* support for moving writing of specific +streams (such as events) off of the main process to a particular worker. (This +is only supported with Redis-based replication.) + +Currently support streams are `events` and `typing`. + +To enable this, the worker must have a HTTP replication listener configured, +have a `worker_name` and be listed in the `instance_map` config. For example to +move event persistence off to a dedicated worker, the shared configuration would +include: ```yaml -worker_app: synapse.app.federation_reader -... -worker_listeners: - - type: http - port: <port> - resources: - - names: - - federation +instance_map: + event_persister1: + host: localhost + port: 8034 + +stream_writers: + events: event_persister1 ``` + +### `synapse.app.pusher` + +Handles sending push notifications to sygnal and email. Doesn't handle any +REST endpoints itself, but you should set `start_pushers: False` in the +shared configuration file to stop the main synapse sending push notifications. + +Note this worker cannot be load-balanced: only one instance should be active. + +### `synapse.app.appservice` + +Handles sending output traffic to Application Services. Doesn't handle any +REST endpoints itself, but you should set `notify_appservices: False` in the +shared configuration file to stop the main synapse sending appservice notifications. + +Note this worker cannot be load-balanced: only one instance should be active. + + ### `synapse.app.federation_sender` Handles sending federation traffic to other servers. Doesn't handle any REST endpoints itself, but you should set `send_federation: False` in the shared configuration file to stop the main synapse sending this traffic. -Note this worker cannot be load-balanced: only one instance should be active. +If running multiple federation senders then you must list each +instance in the `federation_sender_instances` option by their `worker_name`. +All instances must be stopped and started when adding or removing instances. +For example: + +```yaml +federation_sender_instances: + - federation_sender1 + - federation_sender2 +``` ### `synapse.app.media_repository` @@ -307,47 +367,12 @@ expose the `media` resource. For example: - media ``` -Note this worker cannot be load-balanced: only one instance should be active. - -### `synapse.app.client_reader` - -Handles client API endpoints. It can handle REST endpoints matching the -following regular expressions: +Note that if running multiple media repositories they must be on the same server +and you must configure a single instance to run the background tasks, e.g.: - ^/_matrix/client/(api/v1|r0|unstable)/publicRooms$ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/joined_members$ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/context/.*$ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$ - ^/_matrix/client/(api/v1|r0|unstable)/login$ - ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$ - ^/_matrix/client/(api/v1|r0|unstable)/keys/query$ - ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$ - ^/_matrix/client/versions$ - ^/_matrix/client/(api/v1|r0|unstable)/voip/turnServer$ - ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$ - ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$ - ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/ - -Additionally, the following REST endpoints can be handled for GET requests: - - ^/_matrix/client/(api/v1|r0|unstable)/pushrules/.*$ - ^/_matrix/client/(api/v1|r0|unstable)/groups/.*$ - ^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/account_data/ - ^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/rooms/[^/]*/account_data/ - -Additionally, the following REST endpoints can be handled, but all requests must -be routed to the same instance: - - ^/_matrix/client/(r0|unstable)/register$ - ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$ - -Pagination requests can also be handled, but all requests with the same path -room must be routed to the same instance. Additionally, care must be taken to -ensure that the purge history admin API is not used while pagination requests -for the room are in flight: - - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$ +```yaml + media_instance_running_background_jobs: "media-repository-1" +``` ### `synapse.app.user_dir` @@ -383,15 +408,65 @@ file. For example: worker_main_http_uri: http://127.0.0.1:8008 -### `synapse.app.event_creator` +### Historical apps -Handles some event creation. It can handle REST endpoints matching: +*Note:* Historically there used to be more apps, however they have been +amalgamated into a single `synapse.app.generic_worker` app. The remaining apps +are ones that do specific processing unrelated to requests, e.g. the `pusher` +that handles sending out push notifications for new events. The intention is for +all these to be folded into the `generic_worker` app and to use config to define +which processes handle the various proccessing such as push notifications. - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/ - ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ - ^/_matrix/client/(api/v1|r0|unstable)/join/ - ^/_matrix/client/(api/v1|r0|unstable)/profile/ -It will create events locally and then send them on to the main synapse -instance to be persisted and handled. +## Migration from old config + +There are two main independent changes that have been made: introducing Redis +support and merging apps into `synapse.app.generic_worker`. Both these changes +are backwards compatible and so no changes to the config are required, however +server admins are encouraged to plan to migrate to Redis as the old style direct +TCP replication config is deprecated. + +To migrate to Redis add the `redis` config as above, and optionally remove the +TCP `replication` listener from master and `worker_replication_port` from worker +config. + +To migrate apps to use `synapse.app.generic_worker` simply update the +`worker_app` option in the worker configs, and where worker are started (e.g. +in systemd service files, but not required for synctl). + + +## Architectural diagram + +The following shows an example setup using Redis and a reverse proxy: + +``` + Clients & Federation + | + v + +-----------+ + | | + | Reverse | + | Proxy | + | | + +-----------+ + | | | + | | | HTTP requests + +-------------------+ | +-----------+ + | +---+ | + | | | + v v v ++--------------+ +--------------+ +--------------+ +--------------+ +| Main | | Generic | | Generic | | Event | +| Process | | Worker 1 | | Worker 2 | | Persister | ++--------------+ +--------------+ +--------------+ +--------------+ + ^ ^ | ^ | | ^ | ^ ^ + | | | | | | | | | | + | | | | | HTTP | | | | | + | +----------+<--|---|---------+ | | | | + | | +-------------|-->+----------+ | + | | | | + | | | | + v v v v +==================================================================== + Redis pub/sub channel +``` diff --git a/mypy.ini b/mypy.ini index 3533797d68..7764f17856 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,11 +1,66 @@ [mypy] namespace_packages = True -plugins = mypy_zope:plugin +plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py follow_imports = silent check_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs +files = + synapse/api, + synapse/appservice, + synapse/config, + synapse/event_auth.py, + synapse/events/builder.py, + synapse/events/spamcheck.py, + synapse/federation, + synapse/handlers/auth.py, + synapse/handlers/cas_handler.py, + synapse/handlers/directory.py, + synapse/handlers/events.py, + synapse/handlers/federation.py, + synapse/handlers/identity.py, + synapse/handlers/initial_sync.py, + synapse/handlers/message.py, + synapse/handlers/oidc_handler.py, + synapse/handlers/pagination.py, + synapse/handlers/presence.py, + synapse/handlers/room.py, + synapse/handlers/room_member.py, + synapse/handlers/room_member_worker.py, + synapse/handlers/saml_handler.py, + synapse/handlers/sync.py, + synapse/handlers/ui_auth, + synapse/http/federation/well_known_resolver.py, + synapse/http/server.py, + synapse/http/site.py, + synapse/logging/, + synapse/metrics, + synapse/module_api, + synapse/notifier.py, + synapse/push/pusherpool.py, + synapse/push/push_rule_evaluator.py, + synapse/replication, + synapse/rest, + synapse/server.py, + synapse/server_notices, + synapse/spam_checker_api, + synapse/state, + synapse/storage/databases/main/stream.py, + synapse/storage/databases/main/ui_auth.py, + synapse/storage/database.py, + synapse/storage/engines, + synapse/storage/state.py, + synapse/storage/util, + synapse/streams, + synapse/types.py, + synapse/util/caches/descriptors.py, + synapse/util/caches/stream_change_cache.py, + synapse/util/metrics.py, + tests/replication, + tests/test_utils, + tests/rest/client/v2_alpha/test_auth.py, + tests/util/test_stream_change_cache.py [mypy-pymacaroons.*] ignore_missing_imports = True @@ -78,3 +133,9 @@ ignore_missing_imports = True [mypy-authlib.*] ignore_missing_imports = True + +[mypy-rust_python_jaeger_reporter.*] +ignore_missing_imports = True + +[mypy-nacl.*] +ignore_missing_imports = True diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index e6f4bd1dca..d055cf3287 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -24,7 +24,6 @@ DISTS = ( "debian:sid", "ubuntu:xenial", "ubuntu:bionic", - "ubuntu:eoan", "ubuntu:focal", ) diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index 98a618f6b2..448cadb829 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -3,6 +3,8 @@ # A script which checks that an appropriate news file has been added on this # branch. +echo -e "+++ \033[32mChecking newsfragment\033[m" + set -e # make sure that origin/develop is up to date @@ -16,6 +18,8 @@ pr="$BUILDKITE_PULL_REQUEST" if ! git diff --quiet FETCH_HEAD... -- debian; then if git diff --quiet FETCH_HEAD... -- debian/changelog; then echo "Updates to debian directory, but no update to the changelog." >&2 + echo "!! Please see the contributing guide for help writing your changelog entry:" >&2 + echo "https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#debian-changelog" >&2 exit 1 fi fi @@ -26,7 +30,12 @@ if ! git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then exit 0 fi -tox -qe check-newsfragment +# Print a link to the contributing guide if the user makes a mistake +CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing your changelog entry: +https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog" + +# If check-newsfragment returns a non-zero exit code, print the contributing guide and exit +tox -qe check-newsfragment || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) echo echo "--------------------------" @@ -38,6 +47,7 @@ for f in `git diff --name-only FETCH_HEAD... -- changelog.d`; do lastchar=`tr -d '\n' < $f | tail -c 1` if [ $lastchar != '.' -a $lastchar != '!' ]; then echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2 + echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 exit 1 fi @@ -47,5 +57,6 @@ done if [[ -n "$pr" && "$matched" -eq 0 ]]; then echo -e "\e[31mERROR: Did not find a news fragment with the right number: expected changelog.d/$pr.*.\e[39m" >&2 + echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 exit 1 fi diff --git a/scripts-dev/check_line_terminators.sh b/scripts-dev/check_line_terminators.sh new file mode 100755 index 0000000000..c983956231 --- /dev/null +++ b/scripts-dev/check_line_terminators.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This script checks that line terminators in all repository files (excluding +# those in the .git directory) feature unix line terminators. +# +# Usage: +# +# ./check_line_terminators.sh +# +# The script will emit exit code 1 if any files that do not use unix line +# terminators are found, 0 otherwise. + +# cd to the root of the repository +cd `dirname $0`/.. + +# Find and print files with non-unix line terminators +if find . -path './.git/*' -prune -o -type f -print0 | xargs -0 grep -I -l $'\r$'; then + echo -e '\e[31mERROR: found files with CRLF line endings. See above.\e[39m' + exit 1 +fi diff --git a/scripts-dev/check_signature.py b/scripts-dev/check_signature.py index ecda103cf7..6755bc5282 100644 --- a/scripts-dev/check_signature.py +++ b/scripts-dev/check_signature.py @@ -2,9 +2,9 @@ import argparse import json import logging import sys -import urllib2 import dns.resolver +import urllib2 from signedjson.key import decode_verify_key_bytes, write_signing_keys from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 7c19e405d4..ad12523c4d 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -21,11 +21,12 @@ import argparse import base64 import json import sys - -from six.moves.urllib import parse as urlparse +from typing import Any, Optional +from urllib import parse as urlparse import nacl.signing import requests +import signedjson.types import srvlookup import yaml from requests.adapters import HTTPAdapter @@ -70,7 +71,9 @@ def encode_canonical_json(value): ).encode("UTF-8") -def sign_json(json_object, signing_key, signing_name): +def sign_json( + json_object: Any, signing_key: signedjson.types.SigningKey, signing_name: str +) -> Any: signatures = json_object.pop("signatures", {}) unsigned = json_object.pop("unsigned", None) @@ -123,7 +126,14 @@ def read_signing_keys(stream): return keys -def request_json(method, origin_name, origin_key, destination, path, content): +def request( + method: Optional[str], + origin_name: str, + origin_key: signedjson.types.SigningKey, + destination: str, + path: str, + content: Optional[str], +) -> requests.Response: if method is None: if content is None: method = "GET" @@ -160,11 +170,14 @@ def request_json(method, origin_name, origin_key, destination, path, content): if method == "POST": headers["Content-Type"] = "application/json" - result = s.request( - method=method, url=dest, headers=headers, verify=False, data=content + return s.request( + method=method, + url=dest, + headers=headers, + verify=False, + data=content, + stream=True, ) - sys.stderr.write("Status Code: %d\n" % (result.status_code,)) - return result.json() def main(): @@ -223,7 +236,7 @@ def main(): with open(args.signing_key_path) as f: key = read_signing_keys(f)[0] - result = request_json( + result = request( args.method, args.server_name, key, @@ -232,7 +245,12 @@ def main(): content=args.body, ) - json.dump(result, sys.stdout) + sys.stderr.write("Status Code: %d\n" % (result.status_code,)) + + for chunk in result.iter_content(): + # we write raw utf8 to stdout. + sys.stdout.buffer.write(chunk) + print("") diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index bf3862a386..89acb52e6a 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -15,7 +15,7 @@ from synapse.storage.pdu import PduStore from synapse.storage.signatures import SignatureStore -class Store(object): +class Store: _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"] _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"] _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"] diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 34c4854e11..0647993658 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -2,8 +2,8 @@ # # Runs linting scripts over the local Synapse checkout # isort - sorts import statements -# flake8 - lints and finds mistakes # black - opinionated code formatter +# flake8 - lints and finds mistakes set -e @@ -11,11 +11,11 @@ if [ $# -ge 1 ] then files=$* else - files="synapse tests scripts-dev scripts" + files="synapse tests scripts-dev scripts contrib synctl" fi echo "Linting these locations: $files" -isort -y -rc $files -flake8 $files +isort $files python3 -m black $files ./scripts-dev/config-lint.sh +flake8 $files diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py new file mode 100644 index 0000000000..a5b88731f1 --- /dev/null +++ b/scripts-dev/mypy_synapse_plugin.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This is a mypy plugin for Synpase to deal with some of the funky typing that +can crop up, e.g the cache descriptors. +""" + +from typing import Callable, Optional + +from mypy.plugin import MethodSigContext, Plugin +from mypy.typeops import bind_self +from mypy.types import CallableType + + +class SynapsePlugin(Plugin): + def get_method_signature_hook( + self, fullname: str + ) -> Optional[Callable[[MethodSigContext], CallableType]]: + if fullname.startswith( + "synapse.util.caches.descriptors._CachedFunction.__call__" + ): + return cached_function_method_signature + return None + + +def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: + """Fixes the `_CachedFunction.__call__` signature to be correct. + + It already has *almost* the correct signature, except: + + 1. the `self` argument needs to be marked as "bound"; and + 2. any `cache_context` argument should be removed. + """ + + # First we mark this as a bound function signature. + signature = bind_self(ctx.default_signature) + + # Secondly, we remove any "cache_context" args. + # + # Note: We should be only doing this if `cache_context=True` is set, but if + # it isn't then the code will raise an exception when its called anyway, so + # its not the end of the world. + context_arg_index = None + for idx, name in enumerate(signature.arg_names): + if name == "cache_context": + context_arg_index = idx + break + + if context_arg_index: + arg_types = list(signature.arg_types) + arg_types.pop(context_arg_index) + + arg_names = list(signature.arg_names) + arg_names.pop(context_arg_index) + + arg_kinds = list(signature.arg_kinds) + arg_kinds.pop(context_arg_index) + + signature = signature.copy_modified( + arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, + ) + + return signature + + +def plugin(version: str): + # This is the entry point of the plugin, and let's us deal with the fact + # that the mypy plugin interface is *not* stable by looking at the version + # string. + # + # However, since we pin the version of mypy Synapse uses in CI, we don't + # really care. + return SynapsePlugin diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 94aa8758b4..56365e2b58 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -40,7 +40,7 @@ class MockHomeserver(HomeServer): config.server_name, reactor=reactor, config=config, **kwargs ) - self.version_string = "Synapse/"+get_version_string(synapse) + self.version_string = "Synapse/" + get_version_string(synapse) if __name__ == "__main__": @@ -86,7 +86,7 @@ if __name__ == "__main__": store = hs.get_datastore() async def run_background_updates(): - await store.db.updates.run_background_updates(sleep=False) + await store.db_pool.updates.run_background_updates(sleep=False) # Stop the reactor to exit the script once every background update is run. reactor.stop() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 9a0fbc61d8..a34bdf1830 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -23,8 +23,6 @@ import sys import time import traceback -from six import string_types - import yaml from twisted.internet import defer, reactor @@ -37,30 +35,29 @@ from synapse.logging.context import ( make_deferred_yieldable, run_in_background, ) -from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore -from synapse.storage.data_stores.main.deviceinbox import ( - DeviceInboxBackgroundUpdateStore, -) -from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore -from synapse.storage.data_stores.main.events_bg_updates import ( +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore +from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore +from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore +from synapse.storage.databases.main.events_bg_updates import ( EventsBackgroundUpdatesStore, ) -from synapse.storage.data_stores.main.media_repository import ( +from synapse.storage.databases.main.media_repository import ( MediaRepositoryBackgroundUpdateStore, ) -from synapse.storage.data_stores.main.registration import ( +from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, + find_max_generated_user_id_localpart, ) -from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore -from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore -from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore -from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore -from synapse.storage.data_stores.main.stats import StatsStore -from synapse.storage.data_stores.main.user_directory import ( +from synapse.storage.databases.main.room import RoomBackgroundUpdateStore +from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore +from synapse.storage.databases.main.search import SearchBackgroundUpdateStore +from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore +from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) -from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.database import Database, make_conn +from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -91,6 +88,7 @@ BOOLEAN_COLUMNS = { "account_validity": ["email_sent"], "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], + "local_media_repository": ["safe_from_quarantine"], } @@ -129,6 +127,26 @@ APPEND_ONLY_TABLES = [ ] +IGNORED_TABLES = { + # We don't port these tables, as they're a faff and we can regenerate + # them anyway. + "user_directory", + "user_directory_search", + "user_directory_search_content", + "user_directory_search_docsize", + "user_directory_search_segdir", + "user_directory_search_segments", + "user_directory_search_stat", + "user_directory_search_pos", + "users_who_share_private_rooms", + "users_in_public_room", + # UI auth sessions have foreign keys so additional care needs to be taken, + # the sessions are transient anyway, so ignore them. + "ui_auth_sessions", + "ui_auth_sessions_credentials", +} + + # Error returned by the run function. Used at the top-level part of the script to # handle errors and return codes. end_error = None @@ -155,14 +173,14 @@ class Store( StatsStore, ): def execute(self, f, *args, **kwargs): - return self.db.runInteraction(f.__name__, f, *args, **kwargs) + return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) def execute_sql(self, sql, *args): def r(txn): txn.execute(sql, args) return txn.fetchall() - return self.db.runInteraction("execute_sql", r) + return self.db_pool.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( @@ -207,7 +225,7 @@ class Porter(object): async def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = await self.postgres_store.db.simple_select_one( + row = await self.postgres_store.db_pool.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -224,7 +242,7 @@ class Porter(object): ) = await self._setup_sent_transactions() backward_chunk = 0 else: - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -254,7 +272,7 @@ class Porter(object): await self.postgres_store.execute(delete_all) - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -291,21 +309,14 @@ class Porter(object): ) return - if table in ( - "user_directory", - "user_directory_search", - "users_who_share_rooms", - "users_in_pubic_room", - ): - # We don't port these tables, as they're a faff and we can regenreate - # them anyway. + if table in IGNORED_TABLES: self.progress.update(table, table_size) # Mark table as done return if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -346,7 +357,7 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = await self.sqlite_store.db.runInteraction( + headers, frows, brows = await self.sqlite_store.db_pool.runInteraction( "select", r ) @@ -362,7 +373,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store.db.simple_update_one_txn( + self.postgres_store.db_pool.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -400,7 +411,7 @@ class Porter(object): return headers, rows - headers, rows = await self.sqlite_store.db.runInteraction("select", r) + headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) if rows: forward_chunk = rows[-1][0] + 1 @@ -438,7 +449,7 @@ class Porter(object): ], ) - self.postgres_store.db.simple_update_one_txn( + self.postgres_store.db_pool.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -481,7 +492,7 @@ class Porter(object): db_conn, allow_outdated_version=allow_outdated_version ) prepare_database(db_conn, engine, config=self.hs_config) - store = Store(Database(hs, db_config, engine), db_conn, hs) + store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) db_conn.commit() return store @@ -489,7 +500,7 @@ class Porter(object): async def run_background_updates_on_postgres(self): # Manually apply all background updates on the PostgreSQL database. postgres_ready = ( - await self.postgres_store.db.updates.has_completed_background_updates() + await self.postgres_store.db_pool.updates.has_completed_background_updates() ) if not postgres_ready: @@ -498,9 +509,9 @@ class Porter(object): self.progress.set_state("Running background updates on PostgreSQL") while not postgres_ready: - await self.postgres_store.db.updates.do_next_background_update(100) + await self.postgres_store.db_pool.updates.do_next_background_update(100) postgres_ready = await ( - self.postgres_store.db.updates.has_completed_background_updates() + self.postgres_store.db_pool.updates.has_completed_background_updates() ) async def run(self): @@ -521,7 +532,7 @@ class Porter(object): # Check if all background updates are done, abort if not. updates_complete = ( - await self.sqlite_store.db.updates.has_completed_background_updates() + await self.sqlite_store.db_pool.updates.has_completed_background_updates() ) if not updates_complete: end_error = ( @@ -563,22 +574,24 @@ class Porter(object): ) try: - await self.postgres_store.db.runInteraction("alter_table", alter_table) + await self.postgres_store.db_pool.runInteraction( + "alter_table", alter_table + ) except Exception: # On Error Resume Next pass - await self.postgres_store.db.runInteraction( + await self.postgres_store.db_pool.runInteraction( "create_port_table", create_port_table ) # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = await self.sqlite_store.db.simple_select_onecol( + sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = await self.postgres_store.db.simple_select_onecol( + postgres_tables = await self.postgres_store.db_pool.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -610,8 +623,10 @@ class Porter(object): ) ) - # Step 5. Do final post-processing + # Step 5. Set up sequences + self.progress.set_state("Setting up sequence generators") await self._setup_state_group_id_seq() + await self._setup_user_id_seq() self.progress.done() except Exception as e: @@ -635,7 +650,7 @@ class Porter(object): return bool(col) if isinstance(col, bytes): return bytearray(col) - elif isinstance(col, string_types) and "\0" in col: + elif isinstance(col, str) and "\0" in col: logger.warning( "DROPPING ROW: NUL value in table %s col %s: %r", table, @@ -677,7 +692,7 @@ class Porter(object): return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = await self.sqlite_store.db.runInteraction("select", r) + headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -710,7 +725,7 @@ class Porter(object): next_chunk = await self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - await self.postgres_store.db.simple_insert( + await self.postgres_store.db_pool.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -779,7 +794,14 @@ class Porter(object): next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) + return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) + + def _setup_user_id_seq(self): + def r(txn): + next_id = find_max_generated_user_id_localpart(txn) + 1 + txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) + + return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) ############################################## diff --git a/setup.cfg b/setup.cfg index 12a7849081..a32278ea8a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,12 +26,11 @@ ignore=W503,W504,E203,E731,E501 [isort] line_length = 88 -not_skip = __init__.py sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER default_section=THIRDPARTY known_first_party = synapse known_tests=tests -known_compat = mock,six +known_compat = mock known_twisted=twisted,OpenSSL multi_line_output=3 include_trailing_comma=true diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi new file mode 100644 index 0000000000..3f3af59f26 --- /dev/null +++ b/stubs/frozendict.pyi @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Stub for frozendict. + +from typing import ( + Any, + Hashable, + Iterable, + Iterator, + Mapping, + overload, + Tuple, + TypeVar, +) + +_KT = TypeVar("_KT", bound=Hashable) # Key type. +_VT = TypeVar("_VT") # Value type. + +class frozendict(Mapping[_KT, _VT]): + @overload + def __init__(self, **kwargs: _VT) -> None: ... + @overload + def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + def __getitem__(self, key: _KT) -> _VT: ... + def __contains__(self, key: Any) -> bool: ... + def copy(self, **add_or_replace: Any) -> frozendict: ... + def __iter__(self) -> Iterator[_KT]: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index cac689d4f3..c66413f003 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -22,6 +22,7 @@ class RedisProtocol: def publish(self, channel: str, message: bytes): ... class SubscriberProtocol: + def __init__(self, *args, **kwargs): ... password: Optional[str] def subscribe(self, channels: Union[str, List[str]]): ... def connectionMade(self): ... diff --git a/synapse/__init__.py b/synapse/__init__.py index 4d39996a2e..bf0bf192a5 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -17,6 +17,7 @@ """ This is a reference implementation of a Matrix homeserver. """ +import json import os import sys @@ -25,6 +26,9 @@ if sys.version_info < (3, 5): print("Synapse requires Python 3.5 or above.") sys.exit(1) +# Twisted and canonicaljson will fail to import when this file is executed to +# get the __version__ during a fresh install. That's OK and subsequent calls to +# actually start Synapse will import these libraries fine. try: from twisted.internet import protocol from twisted.internet.protocol import Factory @@ -36,7 +40,15 @@ try: except ImportError: pass -__version__ = "1.15.1" +# Use the standard library json implementation instead of simplejson. +try: + from canonicaljson import set_json_library + + set_json_library(json) +except ImportError: + pass + +__version__ = "1.20.0rc3" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index d528450c78..55cce2db22 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -23,8 +23,6 @@ import hmac import logging import sys -from six.moves import input - import requests as _requests import yaml diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 06ade25674..75388643ee 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -12,19 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Optional - -from six import itervalues +from typing import List, Optional, Tuple import pymacaroons from netaddr import IPAddress -from twisted.internet import defer from twisted.web.server import Request -import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth from synapse.api.auth_blocking import AuthBlocking @@ -37,6 +32,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.logging import opentracing as opentracing from synapse.types import StateMap, UserID from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache @@ -62,7 +58,7 @@ class _InvalidMacaroonException(Exception): pass -class Auth(object): +class Auth: """ FIXME: This class contains a mix of functions for authenticating users of our client-server API and authenticating events added to room graphs. @@ -83,28 +79,28 @@ class Auth(object): self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key - @defer.inlineCallbacks - def check_from_context(self, room_version: str, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids() - auth_events_ids = yield self.compute_auth_events( + async def check_from_context( + self, room_version: str, event, context, do_sig_check=True + ): + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} + auth_events = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version_obj = KNOWN_ROOM_VERSIONS[room_version] event_auth.check( room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check ) - @defer.inlineCallbacks - def check_user_in_room( + async def check_user_in_room( self, room_id: str, user_id: str, current_state: Optional[StateMap[EventBase]] = None, allow_departed_users: bool = False, - ): + ) -> EventBase: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. @@ -122,35 +118,35 @@ class Auth(object): Raises: AuthError if the user is/was not in the room. Returns: - Deferred[Optional[EventBase]]: - Membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. + Membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. """ if current_state: member = current_state.get((EventTypes.Member, user_id), None) else: - member = yield self.state.get_current_state( + member = await self.state.get_current_state( room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) - membership = member.membership if member else None - if membership == Membership.JOIN: - return member + if member: + membership = member.membership - # XXX this looks totally bogus. Why do we not allow users who have been banned, - # or those who were members previously and have been re-invited? - if allow_departed_users and membership == Membership.LEAVE: - forgot = yield self.store.did_forget(user_id, room_id) - if not forgot: + if membership == Membership.JOIN: return member + # XXX this looks totally bogus. Why do we not allow users who have been banned, + # or those who were members previously and have been re-invited? + if allow_departed_users and membership == Membership.LEAVE: + forgot = await self.store.did_forget(user_id, room_id) + if not forgot: + return member + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) - @defer.inlineCallbacks - def check_host_in_room(self, room_id, host): + async def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - latest_event_ids = yield self.store.is_host_joined(room_id, host) + latest_event_ids = await self.store.is_host_joined(room_id, host) return latest_event_ids def can_federate(self, event, auth_events): @@ -161,14 +157,13 @@ class Auth(object): def get_public_keys(self, invite_event): return event_auth.get_public_keys(invite_event) - @defer.inlineCallbacks - def get_user_by_req( + async def get_user_by_req( self, request: Request, allow_guest: bool = False, rights: str = "access", allow_expired: bool = False, - ): + ) -> synapse.types.Requester: """ Get a registered user's ID. Args: @@ -181,7 +176,7 @@ class Auth(object): /login will deliver access tokens regardless of expiration. Returns: - defer.Deferred: resolves to a `synapse.types.Requester` object + Resolves to the requester Raises: InvalidClientCredentialsError if no user by that token exists or the token is invalid. @@ -195,14 +190,14 @@ class Auth(object): access_token = self.get_access_token_from_request(request) - user_id, app_service = yield self._get_appservice_user_id(request) + user_id, app_service = await self._get_appservice_user_id(request) if user_id: request.authenticated_entity = user_id opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("appservice_id", app_service.id) if ip_addr and self._track_appservice_user_ips: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user_id, access_token=access_token, ip=ip_addr, @@ -212,17 +207,18 @@ class Auth(object): return synapse.types.create_requester(user_id, app_service=app_service) - user_info = yield self.get_user_by_access_token( + user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) user = user_info["user"] token_id = user_info["token_id"] is_guest = user_info["is_guest"] + shadow_banned = user_info["shadow_banned"] # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: user_id = user.to_string() - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) if ( expiration_ts is not None and self.clock.time_msec() >= expiration_ts @@ -236,7 +232,7 @@ class Auth(object): device_id = user_info.get("device_id") if user and access_token and ip_addr: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user.to_string(), access_token=access_token, ip=ip_addr, @@ -257,13 +253,17 @@ class Auth(object): opentracing.set_tag("device_id", device_id) return synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service + user, + token_id, + is_guest, + shadow_banned, + device_id, + app_service=app_service, ) except KeyError: raise MissingClientTokenError() - @defer.inlineCallbacks - def _get_appservice_user_id(self, request): + async def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) @@ -284,14 +284,13 @@ class Auth(object): if not app_service.is_interested_in_user(user_id): raise AuthError(403, "Application service cannot masquerade as this user.") - if not (yield self.store.get_user_by_id(user_id)): + if not (await self.store.get_user_by_id(user_id)): raise AuthError(403, "Application service has not registered this user") return user_id, app_service - @defer.inlineCallbacks - def get_user_by_access_token( + async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ): + ) -> dict: """ Validate access token and get user_id from it Args: @@ -301,9 +300,10 @@ class Auth(object): allow_expired: If False, raises an InvalidClientTokenError if the token is expired Returns: - Deferred[dict]: dict that includes: + dict that includes: `user` (UserID) `is_guest` (bool) + `shadow_banned` (bool) `token_id` (int|None): access token id. May be None if guest `device_id` (str|None): device corresponding to access token Raises: @@ -315,7 +315,7 @@ class Auth(object): if rights == "access": # first look in the database - r = yield self._look_up_user_by_access_token(token) + r = await self._look_up_user_by_access_token(token) if r: valid_until_ms = r["valid_until_ms"] if ( @@ -353,7 +353,7 @@ class Auth(object): # It would of course be much easier to store guest access # tokens in the database as well, but that would break existing # guest tokens. - stored_user = yield self.store.get_user_by_id(user_id) + stored_user = await self.store.get_user_by_id(user_id) if not stored_user: raise InvalidClientTokenError("Unknown user_id %s" % user_id) if not stored_user["is_guest"]: @@ -363,6 +363,7 @@ class Auth(object): ret = { "user": user, "is_guest": True, + "shadow_banned": False, "token_id": None, # all guests get the same device id "device_id": GUEST_DEVICE_ID, @@ -372,6 +373,7 @@ class Auth(object): ret = { "user": user, "is_guest": False, + "shadow_banned": False, "token_id": None, "device_id": None, } @@ -483,9 +485,8 @@ class Auth(object): now = self.hs.get_clock().time_msec() return now < expiry - @defer.inlineCallbacks - def _look_up_user_by_access_token(self, token): - ret = yield self.store.get_user_by_access_token(token) + async def _look_up_user_by_access_token(self, token): + ret = await self.store.get_user_by_access_token(token) if not ret: return None @@ -496,6 +497,7 @@ class Auth(object): "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), "is_guest": False, + "shadow_banned": ret.get("shadow_banned"), "device_id": ret.get("device_id"), "valid_until_ms": ret.get("valid_until_ms"), } @@ -508,7 +510,7 @@ class Auth(object): logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender - return defer.succeed(service) + return service async def is_server_admin(self, user: UserID) -> bool: """ Check if the given user is a local server admin. @@ -523,7 +525,7 @@ class Auth(object): def compute_auth_events( self, event, current_state_ids: StateMap[str], for_verification: bool = False, - ): + ) -> List[str]: """Given an event and current state return the list of event IDs used to auth an event. @@ -531,16 +533,16 @@ class Auth(object): should be added to the event's `auth_events`. Returns: - defer.Deferred(list[str]): List of event IDs. + List of event IDs. """ if event.type == EventTypes.Create: - return defer.succeed([]) + return [] # Currently we ignore the `for_verification` flag even though there are # some situations where we can drop particular auth events when adding # to the event's `auth_events` (e.g. joins pointing to previous joins - # when room is publically joinable). Dropping event IDs has the + # when room is publicly joinable). Dropping event IDs has the # advantage that the auth chain for the room grows slower, but we use # the auth chain in state resolution v2 to order events, which means # care must be taken if dropping events to ensure that it doesn't @@ -554,7 +556,7 @@ class Auth(object): if auth_ev_id: auth_ids.append(auth_ev_id) - return defer.succeed(auth_ids) + return auth_ids async def check_can_change_room_list(self, room_id: str, user: UserID): """Determine whether the user is allowed to edit the room's entry in the @@ -637,10 +639,9 @@ class Auth(object): return query_params[0].decode("ascii") - @defer.inlineCallbacks - def check_user_in_room_or_world_readable( + async def check_user_in_room_or_world_readable( self, room_id: str, user_id: str, allow_departed_users: bool = False - ): + ) -> Tuple[str, Optional[str]]: """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. @@ -651,10 +652,9 @@ class Auth(object): members but have now departed Returns: - Deferred[tuple[str, str|None]]: Resolves to the current membership of - the user in the room and the membership event ID of the user. If - the user is not in the room and never has been, then - `(Membership.JOIN, None)` is returned. + Resolves to the current membership of the user in the room and the + membership event ID of the user. If the user is not in the room and + never has been, then `(Membership.JOIN, None)` is returned. """ try: @@ -663,12 +663,12 @@ class Auth(object): # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.check_user_in_room( + member_event = await self.check_user_in_room( room_id, user_id, allow_departed_users=allow_departed_users ) return member_event.membership, member_event.event_id except AuthError: - visibility = yield self.state.get_current_state( + visibility = await self.state.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) if ( diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 5c499b6b4e..d8fafd7cb8 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError from synapse.config.server import is_threepid_reserved @@ -24,7 +22,7 @@ from synapse.config.server import is_threepid_reserved logger = logging.getLogger(__name__) -class AuthBlocking(object): +class AuthBlocking: def __init__(self, hs): self.store = hs.get_datastore() @@ -36,8 +34,7 @@ class AuthBlocking(object): self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids - @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag @@ -60,7 +57,7 @@ class AuthBlocking(object): if user_id is not None: if user_id == self._server_notices_mxid: return - if (yield self.store.is_support_user(user_id)): + if await self.store.is_support_user(user_id): return if self._hs_disabled: @@ -76,11 +73,11 @@ class AuthBlocking(object): # If the user is already part of the MAU cohort or a trial user if user_id: - timestamp = yield self.store.user_last_seen_monthly_active(user_id) + timestamp = await self.store.user_last_seen_monthly_active(user_id) if timestamp: return - is_trial = yield self.store.is_trial_user(user_id) + is_trial = await self.store.is_trial_user(user_id) if is_trial: return elif threepid: @@ -93,7 +90,7 @@ class AuthBlocking(object): # allow registration. Support users are excluded from MAU checks. return # Else if there is no room in the MAU bucket, bail - current_mau = yield self.store.get_monthly_active_count() + current_mau = await self.store.get_monthly_active_count() if current_mau >= self._max_mau_value: raise ResourceLimitError( 403, diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 5ec4a77ccd..46013cde15 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -28,7 +28,7 @@ MAX_ALIAS_LENGTH = 255 MAX_USERID_LENGTH = 255 -class Membership(object): +class Membership: """Represents the membership states of a user in a room.""" @@ -40,7 +40,7 @@ class Membership(object): LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) -class PresenceState(object): +class PresenceState: """Represents the presence state of a user.""" OFFLINE = "offline" @@ -48,14 +48,14 @@ class PresenceState(object): ONLINE = "online" -class JoinRules(object): +class JoinRules: PUBLIC = "public" KNOCK = "knock" INVITE = "invite" PRIVATE = "private" -class LoginType(object): +class LoginType: PASSWORD = "m.login.password" EMAIL_IDENTITY = "m.login.email.identity" MSISDN = "m.login.msisdn" @@ -65,7 +65,7 @@ class LoginType(object): DUMMY = "m.login.dummy" -class EventTypes(object): +class EventTypes: Member = "m.room.member" Create = "m.room.create" Tombstone = "m.room.tombstone" @@ -96,17 +96,17 @@ class EventTypes(object): Presence = "m.presence" -class RejectedReason(object): +class RejectedReason: AUTH_ERROR = "auth_error" -class RoomCreationPreset(object): +class RoomCreationPreset: PRIVATE_CHAT = "private_chat" PUBLIC_CHAT = "public_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat" -class ThirdPartyEntityKind(object): +class ThirdPartyEntityKind: USER = "user" LOCATION = "location" @@ -115,7 +115,7 @@ ServerNoticeMsgType = "m.server_notice" ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" -class UserTypes(object): +class UserTypes: """Allows for user type specific behaviour. With the benefit of hindsight 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ @@ -125,7 +125,7 @@ class UserTypes(object): ALL_USER_TYPES = (SUPPORT, BOT) -class RelationTypes(object): +class RelationTypes: """The types of relations known to this server. """ @@ -134,14 +134,14 @@ class RelationTypes(object): REFERENCE = "m.reference" -class LimitBlockingTypes(object): +class LimitBlockingTypes: """Reasons that a server may be blocked""" MONTHLY_ACTIVE_USER = "monthly_active_user" HS_DISABLED = "hs_disabled" -class EventContentFields(object): +class EventContentFields: """Fields found in events' content, regardless of type.""" # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 @@ -150,3 +150,8 @@ class EventContentFields(object): # Timestamp to delete the event after # cf https://github.com/matrix-org/matrix-doc/pull/2228 SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" + + +class RoomEncryptionAlgorithms: + MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" + DEFAULT = MEGOLM_V1_AES_SHA2 diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d54dfb385d..94a9e58eae 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,19 +17,21 @@ """Contains exceptions and error codes.""" import logging -from typing import Dict, List +import typing +from http import HTTPStatus +from typing import Dict, List, Optional, Union -from six import iteritems -from six.moves import http_client +from twisted.web import http -from canonicaljson import json +from synapse.util import json_decoder -from twisted.web import http +if typing.TYPE_CHECKING: + from synapse.types import JsonDict logger = logging.getLogger(__name__) -class Codes(object): +class Codes: UNRECOGNIZED = "M_UNRECOGNIZED" UNAUTHORIZED = "M_UNAUTHORIZED" FORBIDDEN = "M_FORBIDDEN" @@ -80,11 +82,11 @@ class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes. Attributes: - code (int): HTTP error code - msg (str): string describing the error + code: HTTP error code + msg: string describing the error """ - def __init__(self, code, msg): + def __init__(self, code: Union[int, HTTPStatus], msg: str): super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) # Some calls to this method pass instances of http.HTTPStatus for `code`. @@ -125,16 +127,16 @@ class SynapseError(CodeMessageException): message (as well as an HTTP status code). Attributes: - errcode (str): Matrix error code e.g 'M_FORBIDDEN' + errcode: Matrix error code e.g 'M_FORBIDDEN' """ - def __init__(self, code, msg, errcode=Codes.UNKNOWN): + def __init__(self, code: int, msg: str, errcode: str = Codes.UNKNOWN): """Constructs a synapse error. Args: - code (int): The integer error code (an HTTP response code) - msg (str): The human-readable error message. - errcode (str): The matrix error code e.g 'M_FORBIDDEN' + code: The integer error code (an HTTP response code) + msg: The human-readable error message. + errcode: The matrix error code e.g 'M_FORBIDDEN' """ super(SynapseError, self).__init__(code, msg) self.errcode = errcode @@ -147,10 +149,16 @@ class ProxiedRequestError(SynapseError): """An error from a general matrix endpoint, eg. from a proxied Matrix API call. Attributes: - errcode (str): Matrix error code e.g 'M_FORBIDDEN' + errcode: Matrix error code e.g 'M_FORBIDDEN' """ - def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): + def __init__( + self, + code: int, + msg: str, + errcode: str = Codes.UNKNOWN, + additional_fields: Optional[Dict] = None, + ): super(ProxiedRequestError, self).__init__(code, msg, errcode) if additional_fields is None: self._additional_fields = {} # type: Dict @@ -166,15 +174,15 @@ class ConsentNotGivenError(SynapseError): privacy policy. """ - def __init__(self, msg, consent_uri): + def __init__(self, msg: str, consent_uri: str): """Constructs a ConsentNotGivenError Args: - msg (str): The human-readable error message - consent_url (str): The URL where the user can give their consent + msg: The human-readable error message + consent_url: The URL where the user can give their consent """ super(ConsentNotGivenError, self).__init__( - code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN + code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN ) self._consent_uri = consent_uri @@ -187,14 +195,14 @@ class UserDeactivatedError(SynapseError): authenticated endpoint, but the account has been deactivated. """ - def __init__(self, msg): + def __init__(self, msg: str): """Constructs a UserDeactivatedError Args: - msg (str): The human-readable error message + msg: The human-readable error message """ super(UserDeactivatedError, self).__init__( - code=http_client.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED + code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED ) @@ -203,16 +211,16 @@ class FederationDeniedError(SynapseError): is not on its federation whitelist. Attributes: - destination (str): The destination which has been denied + destination: The destination which has been denied """ - def __init__(self, destination): + def __init__(self, destination: Optional[str]): """Raised by federation client or server to indicate that we are are deliberately not attempting to contact a given server because it is not on our federation whitelist. Args: - destination (str): the domain in question + destination: the domain in question """ self.destination = destination @@ -230,14 +238,16 @@ class InteractiveAuthIncompleteError(Exception): (This indicates we should return a 401 with 'result' as the body) Attributes: - result (dict): the server response to the request, which should be + session_id: The ID of the ongoing interactive auth session. + result: the server response to the request, which should be passed back to the client """ - def __init__(self, result): + def __init__(self, session_id: str, result: "JsonDict"): super(InteractiveAuthIncompleteError, self).__init__( "Interactive auth not yet complete" ) + self.session_id = session_id self.result = result @@ -247,7 +257,6 @@ class UnrecognizedRequestError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.UNRECOGNIZED - message = None if len(args) == 0: message = "Unrecognized request" else: @@ -258,7 +267,7 @@ class UnrecognizedRequestError(SynapseError): class NotFoundError(SynapseError): """An error indicating we can't find the thing you asked for""" - def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND): + def __init__(self, msg: str = "Not found", errcode: str = Codes.NOT_FOUND): super(NotFoundError, self).__init__(404, msg, errcode=errcode) @@ -284,21 +293,23 @@ class InvalidClientCredentialsError(SynapseError): M_UNKNOWN_TOKEN respectively. """ - def __init__(self, msg, errcode): + def __init__(self, msg: str, errcode: str): super().__init__(code=401, msg=msg, errcode=errcode) class MissingClientTokenError(InvalidClientCredentialsError): """Raised when we couldn't find the access token in a request""" - def __init__(self, msg="Missing access token"): + def __init__(self, msg: str = "Missing access token"): super().__init__(msg=msg, errcode="M_MISSING_TOKEN") class InvalidClientTokenError(InvalidClientCredentialsError): """Raised when we didn't understand the access token in a request""" - def __init__(self, msg="Unrecognised access token", soft_logout=False): + def __init__( + self, msg: str = "Unrecognised access token", soft_logout: bool = False + ): super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN") self._soft_logout = soft_logout @@ -316,11 +327,11 @@ class ResourceLimitError(SynapseError): def __init__( self, - code, - msg, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - admin_contact=None, - limit_type=None, + code: int, + msg: str, + errcode: str = Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact: Optional[str] = None, + limit_type: Optional[str] = None, ): self.admin_contact = admin_contact self.limit_type = limit_type @@ -368,10 +379,10 @@ class StoreError(SynapseError): class InvalidCaptchaError(SynapseError): def __init__( self, - code=400, - msg="Invalid captcha.", - error_url=None, - errcode=Codes.CAPTCHA_INVALID, + code: int = 400, + msg: str = "Invalid captcha.", + error_url: Optional[str] = None, + errcode: str = Codes.CAPTCHA_INVALID, ): super(InvalidCaptchaError, self).__init__(code, msg, errcode) self.error_url = error_url @@ -386,10 +397,10 @@ class LimitExceededError(SynapseError): def __init__( self, - code=429, - msg="Too Many Requests", - retry_after_ms=None, - errcode=Codes.LIMIT_EXCEEDED, + code: int = 429, + msg: str = "Too Many Requests", + retry_after_ms: Optional[int] = None, + errcode: str = Codes.LIMIT_EXCEEDED, ): super(LimitExceededError, self).__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms @@ -402,10 +413,10 @@ class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store """ - def __init__(self, current_version): + def __init__(self, current_version: str): """ Args: - current_version (str): the current version of the store they should have used + current_version: the current version of the store they should have used """ super(RoomKeysVersionError, self).__init__( 403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION @@ -417,7 +428,7 @@ class UnsupportedRoomVersionError(SynapseError): """The client's request to create a room used a room version that the server does not support.""" - def __init__(self, msg="Homeserver does not support this room version"): + def __init__(self, msg: str = "Homeserver does not support this room version"): super(UnsupportedRoomVersionError, self).__init__( code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION, ) @@ -439,7 +450,7 @@ class IncompatibleRoomVersionError(SynapseError): failing. """ - def __init__(self, room_version): + def __init__(self, room_version: str): super(IncompatibleRoomVersionError, self).__init__( code=400, msg="Your homeserver does not support the features required to " @@ -459,8 +470,8 @@ class PasswordRefusedError(SynapseError): def __init__( self, - msg="This password doesn't comply with the server's policy", - errcode=Codes.WEAK_PASSWORD, + msg: str = "This password doesn't comply with the server's policy", + errcode: str = Codes.WEAK_PASSWORD, ): super(PasswordRefusedError, self).__init__( code=400, msg=msg, errcode=errcode, @@ -485,19 +496,19 @@ class RequestSendFailed(RuntimeError): self.can_retry = can_retry -def cs_error(msg, code=Codes.UNKNOWN, **kwargs): +def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs): """ Utility method for constructing an error response for client-server interactions. Args: - msg (str): The error message. - code (str): The error code. - kwargs : Additional keys to add to the response. + msg: The error message. + code: The error code. + kwargs: Additional keys to add to the response. Returns: A dict representing the error response JSON. """ err = {"error": msg, "errcode": code} - for key, value in iteritems(kwargs): + for key, value in kwargs.items(): err[key] = value return err @@ -514,7 +525,14 @@ class FederationError(RuntimeError): is wrong (e.g., it referred to an invalid event) """ - def __init__(self, level, code, reason, affected, source=None): + def __init__( + self, + level: str, + code: int, + reason: str, + affected: str, + source: Optional[str] = None, + ): if level not in ["FATAL", "ERROR", "WARN"]: raise ValueError("Level is not valid: %s" % (level,)) self.level = level @@ -541,16 +559,16 @@ class HttpResponseException(CodeMessageException): Represents an HTTP-level failure of an outbound request Attributes: - response (bytes): body of response + response: body of response """ - def __init__(self, code, msg, response): + def __init__(self, code: int, msg: str, response: bytes): """ Args: - code (int): HTTP status code - msg (str): reason phrase from HTTP response status line - response (bytes): body of response + code: HTTP status code + msg: reason phrase from HTTP response status line + response: body of response """ super(HttpResponseException, self).__init__(code, msg) self.response = response @@ -575,7 +593,7 @@ class HttpResponseException(CodeMessageException): # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text try: - j = json.loads(self.response) + j = json_decoder.decode(self.response.decode("utf-8")) except ValueError: j = {} @@ -586,3 +604,11 @@ class HttpResponseException(CodeMessageException): errmsg = j.pop("error", self.msg) return ProxiedRequestError(self.code, errmsg, errcode, j) + + +class ShadowBanError(Exception): + """ + Raised when a shadow-banned user attempts to perform an action. + + This should be caught and a proper "fake" success response sent to the user. + """ diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 8b64d0a285..2a2c9e6f13 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -17,17 +17,13 @@ # limitations under the License. from typing import List -from six import text_type - import jsonschema from canonicaljson import json from jsonschema import FormatChecker -from twisted.internet import defer - from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError -from synapse.storage.presence import UserPresenceState +from synapse.api.presence import UserPresenceState from synapse.types import RoomID, UserID FILTER_SCHEMA = { @@ -134,14 +130,13 @@ def matrix_user_id_validator(user_id_str): return UserID.from_string(user_id_str) -class Filtering(object): +class Filtering: def __init__(self, hs): super(Filtering, self).__init__() self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_user_filter(self, user_localpart, filter_id): - result = yield self.store.get_user_filter(user_localpart, filter_id) + async def get_user_filter(self, user_localpart, filter_id): + result = await self.store.get_user_filter(user_localpart, filter_id) return FilterCollection(result) def add_user_filter(self, user_localpart, user_filter): @@ -173,7 +168,7 @@ class Filtering(object): raise SynapseError(400, str(e)) -class FilterCollection(object): +class FilterCollection: def __init__(self, filter_json): self._filter_json = filter_json @@ -254,7 +249,7 @@ class FilterCollection(object): ) -class Filter(object): +class Filter: def __init__(self, filter_json): self.filter_json = filter_json @@ -313,7 +308,7 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes - contains_url = isinstance(content.get("url"), text_type) + contains_url = isinstance(content.get("url"), str) labels = content.get(EventContentFields.LABELS, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) diff --git a/synapse/storage/presence.py b/synapse/api/presence.py index 18a462f0ee..18a462f0ee 100644 --- a/synapse/storage/presence.py +++ b/synapse/api/presence.py diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index ec6b3a69a2..5d9d5a228f 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,10 +17,11 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.types import Requester from synapse.util import Clock -class Ratelimiter(object): +class Ratelimiter: """ Ratelimit actions marked by arbitrary keys. @@ -43,6 +44,42 @@ class Ratelimiter(object): # * The rate_hz of this particular entry. This can vary per request self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] + def can_requester_do_action( + self, + requester: Requester, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: + """Can the requester perform the action? + + Args: + requester: The requester to key off when rate limiting. The user property + will be used. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + + Returns: + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. + -1 if rate_hz is less than or equal to zero + """ + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + return self.can_do_action( + requester.user.to_string(), rate_hz, burst_count, update, _time_now_s + ) + def can_do_action( self, key: Any, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index d7baf2bc39..f3ecbf36b6 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -18,7 +18,7 @@ from typing import Dict import attr -class EventFormatVersions(object): +class EventFormatVersions: """This is an internal enum for tracking the version of the event format, independently from the room version. """ @@ -35,20 +35,20 @@ KNOWN_EVENT_FORMAT_VERSIONS = { } -class StateResolutionVersions(object): +class StateResolutionVersions: """Enum to identify the state resolution algorithms""" V1 = 1 # room v1 state res V2 = 2 # MSC1442 state res: room v2 and later -class RoomDisposition(object): +class RoomDisposition: STABLE = "stable" UNSTABLE = "unstable" @attr.s(slots=True, frozen=True) -class RoomVersion(object): +class RoomVersion: """An object which describes the unique attributes of a room version.""" identifier = attr.ib() # str; the identifier for this version @@ -69,7 +69,7 @@ class RoomVersion(object): limit_notifications_power_levels = attr.ib(type=bool) -class RoomVersions(object): +class RoomVersions: V1 = RoomVersion( "1", RoomDisposition.STABLE, diff --git a/synapse/api/urls.py b/synapse/api/urls.py index f34434bd67..bbfccf955e 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -17,8 +17,7 @@ """Contains the URL paths to prefix various aspects of the server with. """ import hmac from hashlib import sha256 - -from six.moves.urllib.parse import urlencode +from urllib.parse import urlencode from synapse.config import ConfigError @@ -34,7 +33,7 @@ MEDIA_PREFIX = "/_matrix/media/r0" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" -class ConsentURIBuilder(object): +class ConsentURIBuilder: def __init__(self, hs_config): """ Args: diff --git a/synapse/app/_base.py b/synapse/app/_base.py index dedff81af3..fb476ddaf5 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import gc import logging import os @@ -20,8 +19,8 @@ import signal import socket import sys import traceback +from typing import Iterable -from daemonize import Daemonize from typing_extensions import NoReturn from twisted.internet import defer, error, reactor @@ -29,9 +28,11 @@ from twisted.protocols.tls import TLSMemoryBIOFactory import synapse from synapse.app import check_bind_error +from synapse.config.server import ListenerConfig from synapse.crypto import context_factory from synapse.logging.context import PreserveLoggingContext from synapse.util.async_helpers import Linearizer +from synapse.util.daemonize import daemonize_process from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string @@ -127,17 +128,8 @@ def start_reactor( if print_pidfile: print(pid_file) - daemon = Daemonize( - app=appname, - pid=pid_file, - action=run, - auto_close_fds=False, - verbose=True, - logger=logger, - ) - daemon.start() - else: - run() + daemonize_process(pid_file, logger) + run() def quit_with_error(error_string: str) -> NoReturn: @@ -234,7 +226,7 @@ def refresh_certificate(hs): logger.info("Context factories updated.") -def start(hs, listeners=None): +def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): """ Start a Synapse server or worker. @@ -245,8 +237,8 @@ def start(hs, listeners=None): notify systemd. Args: - hs (synapse.server.HomeServer) - listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml) + hs: homeserver instance + listeners: Listener configuration ('listeners' in homeserver.yaml) """ try: # Set up the SIGHUP machinery. @@ -276,7 +268,7 @@ def start(hs, listeners=None): # It is now safe to start your Synapse. hs.start_listening(listeners) - hs.get_datastore().db.start_profiling() + hs.get_datastore().db_pool.start_profiling() hs.get_pusherpool().start() setup_sentry(hs) @@ -342,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we can run out of file descriptors and infinite loop if we attempt to do too many DNS queries at once + + XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless + you explicitly install twisted.names as the resolver; rather it uses a GAIResolver + backed by the reactor's default threadpool (which is limited to 10 threads). So + (a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't + understand why we would run out of FDs if we did too many lookups at once. + -- richvdh 2020/08/29 """ new_resolver = _LimitedHostnameResolver( reactor.nameResolver, max_dns_requests_in_flight @@ -350,7 +349,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): reactor.installNameResolver(new_resolver) -class _LimitedHostnameResolver(object): +class _LimitedHostnameResolver: """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups. """ @@ -410,7 +409,7 @@ class _LimitedHostnameResolver(object): yield deferred -class _DeferredResolutionReceiver(object): +class _DeferredResolutionReceiver: """Wraps a IResolutionReceiver and simply resolves the given deferred when resolution is complete """ diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index a37818fe9a..b6c9085670 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -79,8 +79,7 @@ class AdminCmdServer(HomeServer): pass -@defer.inlineCallbacks -def export_data_command(hs, args): +async def export_data_command(hs, args): """Export data for a user. Args: @@ -91,10 +90,8 @@ def export_data_command(hs, args): user_id = args.user_id directory = args.output_directory - res = yield defer.ensureDeferred( - hs.get_handlers().admin_handler.export_user_data( - user_id, FileExfiltrationWriter(user_id, directory=directory) - ) + res = await hs.get_handlers().admin_handler.export_user_data( + user_id, FileExfiltrationWriter(user_id, directory=directory) ) print(res) @@ -232,14 +229,15 @@ def start(config_options): # We also make sure that `_base.start` gets run before we actually run the # command. - @defer.inlineCallbacks - def run(_reactor): + async def run(): with LoggingContext("command"): - yield _base.start(ss, []) - yield args.func(ss, args) + _base.start(ss, []) + await args.func(ss, args) _base.start_worker_reactor( - "synapse-admin-cmd", config, run_command=lambda: task.react(run) + "synapse-admin-cmd", + config, + run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())), ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f3ec2a34ec..f985810e88 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set from typing_extensions import ContextManager -from twisted.internet import defer, reactor +from twisted.internet import address, reactor import synapse import synapse.events @@ -37,6 +37,7 @@ from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging +from synapse.config.server import ListenerConfig from synapse.federation import send_queue from synapse.federation.transport.server import TransportLayerServer from synapse.handlers.presence import ( @@ -86,7 +87,6 @@ from synapse.replication.tcp.streams import ( ReceiptsStream, TagAccountDataStream, ToDeviceStream, - TypingStream, ) from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client.v1 import events @@ -110,6 +110,7 @@ from synapse.rest.client.v1.room import ( RoomSendEventRestServlet, RoomStateEventRestServlet, RoomStateRestServlet, + RoomTypingRestServlet, ) from synapse.rest.client.v1.voip import VoipRestServlet from synapse.rest.client.v2_alpha import groups, sync, user_directory @@ -122,17 +123,18 @@ from synapse.rest.client.v2_alpha.account_data import ( from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet +from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.server import HomeServer -from synapse.storage.data_stores.main.censor_events import CensorEventsStore -from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore -from synapse.storage.data_stores.main.monthly_active_users import ( +from synapse.server import HomeServer, cache_in_self +from synapse.storage.databases.main.censor_events import CensorEventsStore +from synapse.storage.databases.main.media_repository import MediaRepositoryStore +from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) -from synapse.storage.data_stores.main.presence import UserPresenceState -from synapse.storage.data_stores.main.search import SearchWorkerStore -from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.storage.databases.main.presence import UserPresenceState +from synapse.storage.databases.main.search import SearchWorkerStore +from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore +from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree @@ -205,10 +207,30 @@ class KeyUploadServlet(RestServlet): if body: # They're actually trying to upload something, proxy to main synapse. - # Pass through the auth headers, if any, in case the access token - # is there. - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) - headers = {"Authorization": auth_headers} + + # Proxy headers from the original request, such as the auth headers + # (in case the access token is there) and the original IP / + # User-Agent of the request. + headers = { + header: request.requestHeaders.getRawHeaders(header, []) + for header in (b"Authorization", b"User-Agent") + } + # Add the previous hop the the X-Forwarded-For header. + x_forwarded_for = request.requestHeaders.getRawHeaders( + b"X-Forwarded-For", [] + ) + if isinstance(request.client, (address.IPv4Address, address.IPv6Address)): + previous_host = request.client.host.encode("ascii") + # If the header exists, add to the comma-separated list of the first + # instance of the header. Otherwise, generate a new header. + if x_forwarded_for: + x_forwarded_for = [ + x_forwarded_for[0] + b", " + previous_host + ] + x_forwarded_for[1:] + else: + x_forwarded_for = [previous_host] + headers[b"X-Forwarded-For"] = x_forwarded_for + try: result = await self.http_client.post_json_get_json( self.main_uri + request.uri.decode("ascii"), body, headers=headers @@ -353,9 +375,8 @@ class GenericWorkerPresence(BasePresenceHandler): return _user_syncing() - @defer.inlineCallbacks - def notify_from_replication(self, states, stream_id): - parties = yield get_interested_parties(self.store, states) + async def notify_from_replication(self, states, stream_id): + parties = await get_interested_parties(self.store, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -365,8 +386,7 @@ class GenericWorkerPresence(BasePresenceHandler): users=users_to_states.keys(), ) - @defer.inlineCallbacks - def process_replication_rows(self, token, rows): + async def process_replication_rows(self, token, rows): states = [ UserPresenceState( row.user_id, @@ -384,7 +404,7 @@ class GenericWorkerPresence(BasePresenceHandler): self.user_to_current_state[state.user_id] = state stream_id = token - yield self.notify_from_replication(states, stream_id) + await self.notify_from_replication(states, stream_id) def get_currently_syncing_users_for_replication(self) -> Iterable[str]: return [ @@ -430,37 +450,6 @@ class GenericWorkerPresence(BasePresenceHandler): await self._bump_active_client(user_id=user_id) -class GenericWorkerTyping(object): - def __init__(self, hs): - self._latest_room_serial = 0 - self._reset() - - def _reset(self): - """ - Reset the typing handler's data caches. - """ - # map room IDs to serial numbers - self._room_serials = {} - # map room IDs to sets of users currently typing - self._room_typing = {} - - def process_replication_rows(self, token, rows): - if self._latest_room_serial > token: - # The master has gone backwards. To prevent inconsistent data, just - # clear everything. - self._reset() - - # Set the latest serial token to whatever the server gave us. - self._latest_room_serial = token - - for row in rows: - self._room_serials[row.room_id] = token - self._room_typing[row.room_id] = row.user_ids - - def get_current_token(self) -> int: - return self._latest_room_serial - - class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. @@ -490,37 +479,27 @@ class GenericWorkerSlavedStore( SearchWorkerStore, BaseSlavedStore, ): - def __init__(self, database, db_conn, hs): - super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs) + pass - # We pull out the current federation stream position now so that we - # always have a known value for the federation position in memory so - # that we don't have to bounce via a deferred once when we start the - # replication streams. - self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) - def _get_federation_out_pos(self, db_conn): - sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?" - sql = self.database_engine.convert_param_style(sql) +class GenericWorkerServer(HomeServer): + DATASTORE_CLASS = GenericWorkerSlavedStore - txn = db_conn.cursor() - txn.execute(sql, ("federation",)) - rows = txn.fetchall() - txn.close() + def _listen_http(self, listener_config: ListenerConfig): + port = listener_config.port + bind_addresses = listener_config.bind_addresses - return rows[0][0] if rows else -1 + assert listener_config.http_options is not None + site_tag = listener_config.http_options.tag + if site_tag is None: + site_tag = port -class GenericWorkerServer(HomeServer): - DATASTORE_CLASS = GenericWorkerSlavedStore + # We always include a health resource. + resources = {"/health": HealthResource()} - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: + for res in listener_config.http_options.resources: + for name in res.names: if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "client": @@ -550,6 +529,7 @@ class GenericWorkerServer(HomeServer): KeyUploadServlet(self).register(resource) AccountDataServlet(self).register(resource) RoomAccountDataServlet(self).register(resource) + RoomTypingRestServlet(self).register(resource) sync.register_servlets(self, resource) events.register_servlets(self, resource) @@ -590,7 +570,7 @@ class GenericWorkerServer(HomeServer): " repository is disabled. Ignoring." ) - if name == "openid" and "federation" not in res["names"]: + if name == "openid" and "federation" not in res.names: # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. @@ -625,19 +605,19 @@ class GenericWorkerServer(HomeServer): logger.info("Synapse worker now listening on port %d", port) - def start_listening(self, listeners): + def start_listening(self, listeners: Iterable[ListenerConfig]): for listener in listeners: - if listener["type"] == "http": + if listener.type == "http": self._listen_http(listener) - elif listener["type"] == "manhole": + elif listener.type == "manhole": _base.listen_tcp( - listener["bind_addresses"], - listener["port"], + listener.bind_addresses, + listener.port, manhole( username="matrix", password="rabbithole", globals={"hs": self} ), ) - elif listener["type"] == "metrics": + elif listener.type == "metrics": if not self.get_config().enable_metrics: logger.warning( ( @@ -646,31 +626,29 @@ class GenericWorkerServer(HomeServer): ) ) else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) + _base.listen_metrics(listener.bind_addresses, listener.port) else: - logger.warning("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unsupported listener type: %s", listener.type) self.get_tcp_replication().start_replication(self) - def remove_pusher(self, app_id, push_key, user_id): + async def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - def build_replication_data_handler(self): + @cache_in_self + def get_replication_data_handler(self): return GenericWorkerReplicationHandler(self) - def build_presence_handler(self): + @cache_in_self + def get_presence_handler(self): return GenericWorkerPresence(self) - def build_typing_handler(self): - return GenericWorkerTyping(self) - class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): super(GenericWorkerReplicationHandler, self).__init__(hs) self.store = hs.get_datastore() - self.typing_handler = hs.get_typing_handler() self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence self.notifier = hs.get_notifier() @@ -707,11 +685,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): await self.pusher_pool.on_new_receipts( token, token, {row.room_id for row in rows} ) - elif stream_name == TypingStream.NAME: - self.typing_handler.process_replication_rows(token, rows) - self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows] - ) elif stream_name == ToDeviceStream.NAME: entities = [row.entity for row in rows if row.entity.startswith("@")] if entities: @@ -738,6 +711,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): except Exception: logger.exception("Error processing replication") + async def on_position(self, stream_name: str, instance_name: str, token: int): + await super().on_position(stream_name, instance_name, token) + # Also call on_rdata to ensure that stream positions are properly reset. + await self.on_rdata(stream_name, instance_name, token, []) + def stop_pusher(self, user_id, app_id, pushkey): if not self.notify_pushers: return @@ -767,7 +745,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): self.send_handler.wake_destination(server) -class FederationSenderHandler(object): +class FederationSenderHandler: """Processes the fedration replication stream This class is only instantiate on the worker responsible for sending outbound @@ -781,19 +759,11 @@ class FederationSenderHandler(object): self.federation_sender = hs.get_federation_sender() self._hs = hs - # if the worker is restarted, we want to pick up where we left off in - # the replication stream, so load the position from the database. - # - # XXX is this actually worthwhile? Whenever the master is restarted, we'll - # drop some rows anyway (which is mostly fine because we're only dropping - # typing and presence notifications). If the replication stream is - # unreliable, why do we do all this hoop-jumping to store the position in the - # database? See also https://github.com/matrix-org/synapse/issues/7535. - # - self.federation_position = self.store.federation_out_pos_startup + # Stores the latest position in the federation stream we've gotten up + # to. This is always set before we use it. + self.federation_position = None self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - self._last_ack = self.federation_position def on_start(self): # There may be some events that are persisted but haven't been sent, @@ -901,7 +871,6 @@ class FederationSenderHandler(object): # We ACK this token over replication so that the master can drop # its in memory queues self._hs.get_tcp_replication().send_federation_ack(current_position) - self._last_ack = current_position except Exception: logger.exception("Error updating federation stream position") @@ -929,7 +898,7 @@ def start(config_options): ) if config.worker_app == "synapse.app.appservice": - if config.notify_appservices: + if config.appservice.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -939,13 +908,13 @@ def start(config_options): sys.exit(1) # Force the appservice to start since they will be disabled in the main config - config.notify_appservices = True + config.appservice.notify_appservices = True else: # For other worker types we force this to off. - config.notify_appservices = False + config.appservice.notify_appservices = False if config.worker_app == "synapse.app.pusher": - if config.start_pushers: + if config.server.start_pushers: sys.stderr.write( "\nThe pushers must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -955,13 +924,13 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.start_pushers = True + config.server.start_pushers = True else: # For other worker types we force this to off. - config.start_pushers = False + config.server.start_pushers = False if config.worker_app == "synapse.app.user_dir": - if config.update_user_directory: + if config.server.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -971,13 +940,13 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.update_user_directory = True + config.server.update_user_directory = True else: # For other worker types we force this to off. - config.update_user_directory = False + config.server.update_user_directory = False if config.worker_app == "synapse.app.federation_sender": - if config.send_federation: + if config.worker.send_federation: sys.stderr.write( "\nThe send_federation must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -987,10 +956,10 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.send_federation = True + config.worker.send_federation = True else: # For other worker types we force this to off. - config.send_federation = False + config.worker.send_federation = False synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8454d74858..6014adc850 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -23,8 +23,7 @@ import math import os import resource import sys - -from six import iteritems +from typing import Iterable from prometheus_client import Gauge @@ -50,12 +49,14 @@ from synapse.app import _base from synapse.app._base import listen_ssl, listen_tcp, quit_with_error from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.additional_resource import AdditionalResource from synapse.http.server import ( OptionsResource, RootOptionsRedirectResource, RootRedirect, + StaticResource, ) from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext @@ -67,6 +68,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource +from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer @@ -89,24 +91,26 @@ def gz_wrap(r): class SynapseHomeServer(HomeServer): DATASTORE_CLASS = DataStore - def _listener_http(self, config, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - tls = listener_config.get("tls", False) - site_tag = listener_config.get("tag", port) + def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig): + port = listener_config.port + bind_addresses = listener_config.bind_addresses + tls = listener_config.tls + site_tag = listener_config.http_options.tag + if site_tag is None: + site_tag = port - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "openid" and "federation" in res["names"]: + # We always include a health resource. + resources = {"/health": HealthResource()} + + for res in listener_config.http_options.resources: + for name in res.names: + if name == "openid" and "federation" in res.names: # Skip loading openid resource if federation is defined # since federation resource will include openid continue - resources.update( - self._configure_named_resource(name, res.get("compress", False)) - ) + resources.update(self._configure_named_resource(name, res.compress)) - additional_resources = listener_config.get("additional_resources", {}) + additional_resources = listener_config.http_options.additional_resources logger.debug("Configuring additional resources: %r", additional_resources) module_api = ModuleApi(self, self.get_auth_handler()) for path, resmodule in additional_resources.items(): @@ -228,7 +232,7 @@ class SynapseHomeServer(HomeServer): if name in ["static", "client"]: resources.update( { - STATIC_PREFIX: File( + STATIC_PREFIX: StaticResource( os.path.join(os.path.dirname(synapse.__file__), "static") ) } @@ -278,7 +282,7 @@ class SynapseHomeServer(HomeServer): return resources - def start_listening(self, listeners): + def start_listening(self, listeners: Iterable[ListenerConfig]): config = self.get_config() if config.redis_enabled: @@ -288,25 +292,25 @@ class SynapseHomeServer(HomeServer): self.get_tcp_replication().start_replication(self) for listener in listeners: - if listener["type"] == "http": + if listener.type == "http": self._listening_services.extend(self._listener_http(config, listener)) - elif listener["type"] == "manhole": + elif listener.type == "manhole": listen_tcp( - listener["bind_addresses"], - listener["port"], + listener.bind_addresses, + listener.port, manhole( username="matrix", password="rabbithole", globals={"hs": self} ), ) - elif listener["type"] == "replication": + elif listener.type == "replication": services = listen_tcp( - listener["bind_addresses"], - listener["port"], + listener.bind_addresses, + listener.port, ReplicationStreamProtocolFactory(self), ) for s in services: reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) - elif listener["type"] == "metrics": + elif listener.type == "metrics": if not self.get_config().enable_metrics: logger.warning( ( @@ -315,9 +319,11 @@ class SynapseHomeServer(HomeServer): ) ) else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) + _base.listen_metrics(listener.bind_addresses, listener.port) else: - logger.warning("Unrecognized listener type: %s", listener["type"]) + # this shouldn't happen, as the listener type should have been checked + # during parsing + logger.warning("Unrecognized listener type: %s", listener.type) # Gauges to expose monthly active user control metrics @@ -377,13 +383,12 @@ def setup(config_options): hs.setup_master() - @defer.inlineCallbacks - def do_acme(): + async def do_acme() -> bool: """ Reprovision an ACME certificate, if it's required. Returns: - Deferred[bool]: Whether the cert has been updated. + Whether the cert has been updated. """ acme = hs.get_acme_handler() @@ -402,30 +407,28 @@ def setup(config_options): provision = True if provision: - yield acme.provision_certificate() + await acme.provision_certificate() return provision - @defer.inlineCallbacks - def reprovision_acme(): + async def reprovision_acme(): """ Provision a certificate from ACME, if required, and reload the TLS certificate if it's renewed. """ - reprovisioned = yield do_acme() + reprovisioned = await do_acme() if reprovisioned: _base.refresh_certificate(hs) - @defer.inlineCallbacks - def start(): + async def start(): try: # Run the ACME provisioning code, if it's enabled. if hs.config.acme_enabled: acme = hs.get_acme_handler() # Start up the webservices which we will respond to ACME # challenges with, and then provision. - yield acme.start_listening() - yield do_acme() + await acme.start_listening() + await do_acme() # Check if it needs to be reprovisioned every day. hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) @@ -434,12 +437,12 @@ def setup(config_options): if hs.config.oidc_enabled: oidc = hs.get_oidc_handler() # Loading the provider metadata also ensures the provider config is valid. - yield defer.ensureDeferred(oidc.load_metadata()) - yield defer.ensureDeferred(oidc.load_jwks()) + await oidc.load_metadata() + await oidc.load_jwks() _base.start(hs, config.listeners) - hs.get_datastore().db.updates.start_doing_background_updates() + hs.get_datastore().db_pool.updates.start_doing_background_updates() except Exception: # Print the exception and bail out. print("Error during startup:", file=sys.stderr) @@ -451,7 +454,7 @@ def setup(config_options): reactor.stop() sys.exit(1) - reactor.callWhenRunning(start) + reactor.callWhenRunning(lambda: defer.ensureDeferred(start())) return hs @@ -480,8 +483,7 @@ class SynapseService(service.Service): _stats_process = [] -@defer.inlineCallbacks -def phone_stats_home(hs, stats, stats_process=_stats_process): +async def phone_stats_home(hs, stats, stats_process=_stats_process): logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) uptime = int(now - hs.start_time) @@ -519,28 +521,28 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): stats["python_version"] = "{}.{}.{}".format( version.major, version.minor, version.micro ) - stats["total_users"] = yield hs.get_datastore().count_all_users() + stats["total_users"] = await hs.get_datastore().count_all_users() - total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() + total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() stats["total_nonbridged_users"] = total_nonbridged_users - daily_user_type_results = yield hs.get_datastore().count_daily_user_type() - for name, count in iteritems(daily_user_type_results): + daily_user_type_results = await hs.get_datastore().count_daily_user_type() + for name, count in daily_user_type_results.items(): stats["daily_user_type_" + name] = count - room_count = yield hs.get_datastore().get_room_count() + room_count = await hs.get_datastore().get_room_count() stats["total_room_count"] = room_count - stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() - stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + stats["daily_active_users"] = await hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = await hs.get_datastore().count_daily_messages() - r30_results = yield hs.get_datastore().count_r30_users() - for name, count in iteritems(r30_results): + r30_results = await hs.get_datastore().count_r30_users() + for name, count in r30_results.items(): stats["r30_users_" + name] = count - daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() + daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages stats["cache_factor"] = hs.config.caches.global_factor stats["event_cache_size"] = hs.config.caches.event_cache_size @@ -550,12 +552,12 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): # # This only reports info about the *main* database. - stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ - stats["database_server_version"] = hs.get_datastore().db.engine.server_version + stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: - yield hs.get_proxied_http_client().put_json( + await hs.get_proxied_http_client().put_json( hs.config.report_stats_endpoint, stats ) except Exception as e: diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 1b13e84425..13ec1f71a6 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,24 +14,25 @@ # limitations under the License. import logging import re - -from six import string_types - -from twisted.internet import defer +from typing import TYPE_CHECKING from synapse.api.constants import EventTypes +from synapse.appservice.api import ApplicationServiceApi from synapse.types import GroupID, get_domain_from_id -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached + +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -class ApplicationServiceState(object): +class ApplicationServiceState: DOWN = "down" UP = "up" -class AppServiceTransaction(object): +class AppServiceTransaction: """Represents an application service transaction.""" def __init__(self, service, id, events): @@ -39,19 +40,19 @@ class AppServiceTransaction(object): self.id = id self.events = events - def send(self, as_api): + async def send(self, as_api: ApplicationServiceApi) -> bool: """Sends this transaction using the provided AS API interface. Args: - as_api(ApplicationServiceApi): The API to use to send. + as_api: The API to use to send. Returns: - A Deferred which resolves to True if the transaction was sent. + True if the transaction was sent. """ - return as_api.push_bulk( + return await as_api.push_bulk( service=self.service, events=self.events, txn_id=self.id ) - def complete(self, store): + async def complete(self, store: "DataStore") -> None: """Completes this transaction as successful. Marks this transaction ID on the application service and removes the @@ -59,13 +60,11 @@ class AppServiceTransaction(object): Args: store: The database store to operate on. - Returns: - A Deferred which resolves to True if the transaction was completed. """ - return store.complete_appservice_txn(service=self.service, txn_id=self.id) + await store.complete_appservice_txn(service=self.service, txn_id=self.id) -class ApplicationService(object): +class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. @@ -156,7 +155,7 @@ class ApplicationService(object): ) regex = regex_obj.get("regex") - if isinstance(regex, string_types): + if isinstance(regex, str): regex_obj["regex"] = re.compile(regex) # Pre-compile regex else: raise ValueError("Expected string for 'regex' in ns '%s'" % ns) @@ -174,8 +173,7 @@ class ApplicationService(object): return regex_obj["exclusive"] return False - @defer.inlineCallbacks - def _matches_user(self, event, store): + async def _matches_user(self, event, store): if not event: return False @@ -190,12 +188,12 @@ class ApplicationService(object): if not store: return False - does_match = yield self._matches_user_in_member_list(event.room_id, store) + does_match = await self._matches_user_in_member_list(event.room_id, store) return does_match - @cachedInlineCallbacks(num_args=1, cache_context=True) - def _matches_user_in_member_list(self, room_id, store, cache_context): - member_list = yield store.get_users_in_room( + @cached(num_args=1, cache_context=True) + async def _matches_user_in_member_list(self, room_id, store, cache_context): + member_list = await store.get_users_in_room( room_id, on_invalidate=cache_context.invalidate ) @@ -210,35 +208,33 @@ class ApplicationService(object): return self.is_interested_in_room(event.room_id) return False - @defer.inlineCallbacks - def _matches_aliases(self, event, store): + async def _matches_aliases(self, event, store): if not store or not event: return False - alias_list = yield store.get_aliases_for_room(event.room_id) + alias_list = await store.get_aliases_for_room(event.room_id) for alias in alias_list: if self.is_interested_in_alias(alias): return True return False - @defer.inlineCallbacks - def is_interested(self, event, store=None): + async def is_interested(self, event, store=None) -> bool: """Check if this service is interested in this event. Args: event(Event): The event to check. store(DataStore) Returns: - bool: True if this service would like to know about this event. + True if this service would like to know about this event. """ # Do cheap checks first if self._matches_room_id(event): return True - if (yield self._matches_aliases(event, store)): + if await self._matches_aliases(event, store): return True - if (yield self._matches_user(event, store)): + if await self._matches_user(event, store): return True return False diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 57174da021..bb6fa8299a 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -13,20 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from six.moves import urllib +import urllib +from typing import TYPE_CHECKING, Optional from prometheus_client import Counter -from twisted.internet import defer - -from synapse.api.constants import ThirdPartyEntityKind +from synapse.api.constants import EventTypes, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient -from synapse.types import ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache +if TYPE_CHECKING: + from synapse.appservice import ApplicationService + logger = logging.getLogger(__name__) sent_transactions_counter = Counter( @@ -94,14 +95,12 @@ class ApplicationServiceApi(SimpleHttpClient): hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS ) - @defer.inlineCallbacks - def query_user(self, service, user_id): + async def query_user(self, service, user_id): if service.url is None: return False uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) - response = None try: - response = yield self.get_json(uri, {"access_token": service.hs_token}) + response = await self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object return True except CodeMessageException as e: @@ -112,14 +111,12 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_user to %s threw exception %s", uri, ex) return False - @defer.inlineCallbacks - def query_alias(self, service, alias): + async def query_alias(self, service, alias): if service.url is None: return False uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) - response = None try: - response = yield self.get_json(uri, {"access_token": service.hs_token}) + response = await self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object return True except CodeMessageException as e: @@ -130,8 +127,7 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_alias to %s threw exception %s", uri, ex) return False - @defer.inlineCallbacks - def query_3pe(self, service, kind, protocol, fields): + async def query_3pe(self, service, kind, protocol, fields): if kind == ThirdPartyEntityKind.USER: required_field = "userid" elif kind == ThirdPartyEntityKind.LOCATION: @@ -148,7 +144,7 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.parse.quote(protocol), ) try: - response = yield self.get_json(uri, fields) + response = await self.get_json(uri, fields) if not isinstance(response, list): logger.warning( "query_3pe to %s returned an invalid response %r", uri, response @@ -169,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_3pe to %s threw exception %s", uri, ex) return [] - def get_3pe_protocol(self, service, protocol): + async def get_3pe_protocol( + self, service: "ApplicationService", protocol: str + ) -> Optional[JsonDict]: if service.url is None: return {} - @defer.inlineCallbacks - def _get(): + async def _get() -> Optional[JsonDict]: uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, urllib.parse.quote(protocol), ) try: - info = yield self.get_json(uri, {}) + info = await self.get_json(uri, {}) if not _is_valid_3pe_metadata(info): logger.warning( @@ -202,14 +199,13 @@ class ApplicationServiceApi(SimpleHttpClient): return None key = (service.id, protocol) - return self.protocol_meta_cache.wrap(key, _get) + return await self.protocol_meta_cache.wrap(key, _get) - @defer.inlineCallbacks - def push_bulk(self, service, events, txn_id=None): + async def push_bulk(self, service, events, txn_id=None): if service.url is None: return True - events = self._serialize(events) + events = self._serialize(service, events) if txn_id is None: logger.warning( @@ -220,7 +216,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) try: - yield self.put_json( + await self.put_json( uri=uri, json_body={"events": events}, args={"access_token": service.hs_token}, @@ -235,6 +231,18 @@ class ApplicationServiceApi(SimpleHttpClient): failed_transactions_counter.labels(service.id).inc() return False - def _serialize(self, events): + def _serialize(self, service, events): time_now = self.clock.time_msec() - return [serialize_event(e, time_now, as_client_event=True) for e in events] + return [ + serialize_event( + e, + time_now, + as_client_event=True, + is_invite=( + e.type == EventTypes.Member + and e.membership == "invite" + and service.is_interested_in_user(e.state_key) + ), + ) + for e in events + ] diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 9998f822f1..8eb8c6f51c 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -50,8 +50,6 @@ components. """ import logging -from twisted.internet import defer - from synapse.appservice import ApplicationServiceState from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -59,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process logger = logging.getLogger(__name__) -class ApplicationServiceScheduler(object): +class ApplicationServiceScheduler: """ Public facing API for this module. Does the required DI to tie the components together. This also serves as the "event_pool", which in this case is a simple array. @@ -73,12 +71,11 @@ class ApplicationServiceScheduler(object): self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) - @defer.inlineCallbacks - def start(self): + async def start(self): logger.info("Starting appservice scheduler") # check for any DOWN ASes and start recoverers for them. - services = yield self.store.get_appservices_by_state( + services = await self.store.get_appservices_by_state( ApplicationServiceState.DOWN ) @@ -89,7 +86,7 @@ class ApplicationServiceScheduler(object): self.queuer.enqueue(service, event) -class _ServiceQueuer(object): +class _ServiceQueuer: """Queue of events waiting to be sent to appservices. Groups events into transactions per-appservice, and sends them on to the @@ -117,8 +114,7 @@ class _ServiceQueuer(object): "as-sender-%s" % (service.id,), self._send_request, service ) - @defer.inlineCallbacks - def _send_request(self, service): + async def _send_request(self, service): # sanity-check: we shouldn't get here if this service already has a sender # running. assert service.id not in self.requests_in_flight @@ -130,14 +126,14 @@ class _ServiceQueuer(object): if not events: return try: - yield self.txn_ctrl.send(service, events) + await self.txn_ctrl.send(service, events) except Exception: logger.exception("AS request failed") finally: self.requests_in_flight.discard(service.id) -class _TransactionController(object): +class _TransactionController: """Transaction manager. Builds AppServiceTransactions and runs their lifecycle. Also starts a Recoverer @@ -162,36 +158,33 @@ class _TransactionController(object): # for UTs self.RECOVERER_CLASS = _Recoverer - @defer.inlineCallbacks - def send(self, service, events): + async def send(self, service, events): try: - txn = yield self.store.create_appservice_txn(service=service, events=events) - service_is_up = yield self._is_service_up(service) + txn = await self.store.create_appservice_txn(service=service, events=events) + service_is_up = await self._is_service_up(service) if service_is_up: - sent = yield txn.send(self.as_api) + sent = await txn.send(self.as_api) if sent: - yield txn.complete(self.store) + await txn.complete(self.store) else: run_in_background(self._on_txn_fail, service) except Exception: logger.exception("Error creating appservice transaction") run_in_background(self._on_txn_fail, service) - @defer.inlineCallbacks - def on_recovered(self, recoverer): + async def on_recovered(self, recoverer): logger.info( "Successfully recovered application service AS ID %s", recoverer.service.id ) self.recoverers.pop(recoverer.service.id) logger.info("Remaining active recoverers: %s", len(self.recoverers)) - yield self.store.set_appservice_state( + await self.store.set_appservice_state( recoverer.service, ApplicationServiceState.UP ) - @defer.inlineCallbacks - def _on_txn_fail(self, service): + async def _on_txn_fail(self, service): try: - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + await self.store.set_appservice_state(service, ApplicationServiceState.DOWN) self.start_recoverer(service) except Exception: logger.exception("Error starting AS recoverer") @@ -211,13 +204,12 @@ class _TransactionController(object): recoverer.recover() logger.info("Now %i active recoverers", len(self.recoverers)) - @defer.inlineCallbacks - def _is_service_up(self, service): - state = yield self.store.get_appservice_state(service) + async def _is_service_up(self, service): + state = await self.store.get_appservice_state(service) return state == ApplicationServiceState.UP or state is None -class _Recoverer(object): +class _Recoverer: """Manages retries and backoff for a DOWN appservice. We have one of these for each appservice which is currently considered DOWN. @@ -254,25 +246,24 @@ class _Recoverer(object): self.backoff_counter += 1 self.recover() - @defer.inlineCallbacks - def retry(self): + async def retry(self): logger.info("Starting retries on %s", self.service.id) try: while True: - txn = yield self.store.get_oldest_unsent_txn(self.service) + txn = await self.store.get_oldest_unsent_txn(self.service) if not txn: # nothing left: we're done! - self.callback(self) + await self.callback(self) return logger.info( "Retrying transaction %s for AS ID %s", txn.id, txn.service.id ) - sent = yield txn.send(self.as_api) + sent = await txn.send(self.as_api) if not sent: break - yield txn.complete(self.store) + await txn.complete(self.store) # reset the backoff counter and then process the next transaction self.backoff_counter = 1 diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index fca35b008c..65043d5b5b 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -16,6 +16,7 @@ from synapse.config._base import ConfigError if __name__ == "__main__": import sys + from synapse.config.homeserver import HomeServerConfig action = sys.argv[1] diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 30d1050a91..ad5ab6ad62 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,12 +18,16 @@ import argparse import errno import os +import time +import urllib.parse from collections import OrderedDict +from hashlib import sha256 from textwrap import dedent -from typing import Any, MutableMapping, Optional - -from six import integer_types +from typing import Any, Callable, List, MutableMapping, Optional +import attr +import jinja2 +import pkg_resources import yaml @@ -84,7 +88,7 @@ def path_exists(file_path): return False -class Config(object): +class Config: """ A configuration section, containing configuration keys and values. @@ -100,6 +104,11 @@ class Config(object): def __init__(self, root_config=None): self.root = root_config + # Get the path to the default Synapse template directory + self.default_template_dir = pkg_resources.resource_filename( + "synapse", "res/templates" + ) + def __getattr__(self, item: str) -> Any: """ Try and fetch a configuration option that does not exist on this class. @@ -117,7 +126,7 @@ class Config(object): @staticmethod def parse_size(value): - if isinstance(value, integer_types): + if isinstance(value, int): return value sizes = {"K": 1024, "M": 1024 * 1024} size = 1 @@ -129,7 +138,7 @@ class Config(object): @staticmethod def parse_duration(value): - if isinstance(value, integer_types): + if isinstance(value, int): return value second = 1000 minute = 60 * second @@ -184,8 +193,97 @@ class Config(object): with open(file_path) as file_stream: return file_stream.read() + def read_templates( + self, filenames: List[str], custom_template_directory: Optional[str] = None, + ) -> List[jinja2.Template]: + """Load a list of template files from disk using the given variables. + + This function will attempt to load the given templates from the default Synapse + template directory. If `custom_template_directory` is supplied, that directory + is tried first. + + Files read are treated as Jinja templates. These templates are not rendered yet. + + Args: + filenames: A list of template filenames to read. + + custom_template_directory: A directory to try to look for the templates + before using the default Synapse template directory instead. + + Raises: + ConfigError: if the file's path is incorrect or otherwise cannot be read. + + Returns: + A list of jinja2 templates. + """ + templates = [] + search_directories = [self.default_template_dir] + + # The loader will first look in the custom template directory (if specified) for the + # given filename. If it doesn't find it, it will use the default template dir instead + if custom_template_directory: + # Check that the given template directory exists + if not self.path_exists(custom_template_directory): + raise ConfigError( + "Configured template directory does not exist: %s" + % (custom_template_directory,) + ) + + # Search the custom template directory as well + search_directories.insert(0, custom_template_directory) + + loader = jinja2.FileSystemLoader(search_directories) + env = jinja2.Environment(loader=loader, autoescape=True) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) + + for filename in filenames: + # Load the template + template = env.get_template(filename) + templates.append(template) + + return templates -class RootConfig(object): + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + + +class RootConfig: """ Holder of an application's configuration. @@ -719,4 +817,36 @@ def find_config_files(search_paths): return config_files -__all__ = ["Config", "RootConfig"] +@attr.s +class ShardedWorkerHandlingConfig: + """Algorithm for choosing which instance is responsible for handling some + sharded work. + + For example, the federation senders use this to determine which instances + handles sending stuff to a given destination (which is used as the `key` + below). + """ + + instances = attr.ib(type=List[str]) + + def should_handle(self, instance_name: str, key: str) -> bool: + """Whether this instance is responsible for handling the given key. + """ + + # If multiple instances are not defined we always return true. + if not self.instances or len(self.instances) == 1: + return True + + # We shard by taking the hash, modulo it by the number of instances and + # then checking whether this instance matches the instance at that + # index. + # + # (Technically this introduces some bias and is not entirely uniform, + # but since the hash is so large the bias is ridiculously small). + dest_hash = sha256(key.encode("utf8")).digest() + dest_int = int.from_bytes(dest_hash, byteorder="little") + remainder = dest_int % (len(self.instances)) + return self.instances[remainder] == instance_name + + +__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 9e576060d4..eb911e8f9f 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -137,3 +137,8 @@ class Config: def read_config_files(config_files: List[str]): ... def find_config_files(search_paths: List[str]): ... + +class ShardedWorkerHandlingConfig: + instances: List[str] + def __init__(self, instances: List[str]) -> None: ... + def should_handle(self, instance_name: str, key: str) -> bool: ... diff --git a/synapse/config/_util.py b/synapse/config/_util.py new file mode 100644 index 0000000000..cd31b1c3c9 --- /dev/null +++ b/synapse/config/_util.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List + +import jsonschema + +from synapse.config._base import ConfigError +from synapse.types import JsonDict + + +def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None: + """Validates a config setting against a JsonSchema definition + + This can be used to validate a section of the config file against a schema + definition. If the validation fails, a ConfigError is raised with a textual + description of the problem. + + Args: + json_schema: the schema to validate against + config: the configuration value to be validated + config_path: the path within the config file. This will be used as a basis + for the error message. + """ + try: + jsonschema.validate(config, json_schema) + except jsonschema.ValidationError as e: + # copy `config_path` before modifying it. + path = list(config_path) + for p in list(e.path): + if isinstance(p, int): + path.append("<item %i>" % p) + else: + path.append(str(p)) + + raise ConfigError( + "Unable to parse configuration: %s at %s" % (e.message, ".".join(path)) + ) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index ca43e96bd1..8ed3e24258 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -14,9 +14,7 @@ import logging from typing import Dict - -from six import string_types -from six.moves.urllib import parse as urlparse +from urllib import parse as urlparse import yaml from netaddr import IPSet @@ -98,17 +96,14 @@ def load_appservices(hostname, config_files): def _load_appservice(hostname, as_info, config_filename): required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"] for field in required_string_fields: - if not isinstance(as_info.get(field), string_types): + if not isinstance(as_info.get(field), str): raise KeyError( "Required string field: '%s' (%s)" % (field, config_filename) ) # 'url' must either be a string or explicitly null, not missing # to avoid accidentally turning off push for ASes. - if ( - not isinstance(as_info.get("url"), string_types) - and as_info.get("url", "") is not None - ): + if not isinstance(as_info.get("url"), str) and as_info.get("url", "") is not None: raise KeyError( "Required string field or explicit null: 'url' (%s)" % (config_filename,) ) @@ -138,7 +133,7 @@ def _load_appservice(hostname, as_info, config_filename): ns, regex_obj, ) - if not isinstance(regex_obj.get("regex"), string_types): + if not isinstance(regex_obj.get("regex"), str): raise ValueError("Missing/bad type 'regex' key in %s", regex_obj) if not isinstance(regex_obj.get("exclusive"), bool): raise ValueError( diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 0672538796..8e03f14005 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -15,6 +15,7 @@ import os import re +import threading from typing import Callable, Dict from ._base import Config, ConfigError @@ -25,11 +26,14 @@ _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" # Map from canonicalised cache name to cache. _CACHES = {} +# a lock on the contents of _CACHES +_CACHES_LOCK = threading.Lock() + _DEFAULT_FACTOR_SIZE = 0.5 _DEFAULT_EVENT_CACHE_SIZE = "10K" -class CacheProperties(object): +class CacheProperties: def __init__(self): # The default factor size for all caches self.default_factor_size = float( @@ -66,7 +70,10 @@ def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): # Some caches have '*' in them which we strip out. cache_name = _canonicalise_cache_name(cache_name) - _CACHES[cache_name] = cache_resize_callback + # sometimes caches are initialised from background threads, so we need to make + # sure we don't conflict with another thread running a resize operation + with _CACHES_LOCK: + _CACHES[cache_name] = cache_resize_callback # Ensure all loaded caches are sized appropriately # @@ -87,7 +94,8 @@ class CacheConfig(Config): os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) ) properties.resize_all_caches_func = None - _CACHES.clear() + with _CACHES_LOCK: + _CACHES.clear() def generate_config_section(self, **kwargs): return """\ @@ -193,6 +201,8 @@ class CacheConfig(Config): For each cache, run the mapped callback function with either a specific cache factor or the default, global one. """ - for cache_name, callback in _CACHES.items(): - new_factor = self.cache_factors.get(cache_name, self.global_factor) - callback(new_factor) + # block other threads from modifying _CACHES while we iterate it. + with _CACHES_LOCK: + for cache_name, callback in _CACHES.items(): + new_factor = self.cache_factors.get(cache_name, self.global_factor) + callback(new_factor) diff --git a/synapse/config/database.py b/synapse/config/database.py index 1064c2697b..8a18a9ca2a 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -55,7 +55,7 @@ DEFAULT_CONFIG = """\ #database: # name: psycopg2 # args: -# user: synapse +# user: synapse_user # password: secretpassword # database: synapse # host: localhost @@ -100,7 +100,10 @@ class DatabaseConnectionConfig: self.name = name self.config = db_config - self.data_stores = data_stores + + # The `data_stores` config is actually talking about `databases` (we + # changed the name). + self.databases = data_stores class DatabaseConfig(Config): diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index ca61214454..7a796996c0 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import print_function # This file can't be called email.py because if it is, we cannot: @@ -23,7 +22,7 @@ import os from enum import Enum from typing import Optional -import pkg_resources +import attr from ._base import Config, ConfigError @@ -33,6 +32,33 @@ Password reset emails are enabled on this homeserver due to a partial %s """ +DEFAULT_SUBJECTS = { + "message_from_person_in_room": "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room...", + "message_from_person": "[%(app)s] You have a message on %(app)s from %(person)s...", + "messages_from_person": "[%(app)s] You have messages on %(app)s from %(person)s...", + "messages_in_room": "[%(app)s] You have messages on %(app)s in the %(room)s room...", + "messages_in_room_and_others": "[%(app)s] You have messages on %(app)s in the %(room)s room and others...", + "messages_from_person_and_others": "[%(app)s] You have messages on %(app)s from %(person)s and others...", + "invite_from_person": "[%(app)s] %(person)s has invited you to chat on %(app)s...", + "invite_from_person_to_room": "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s...", + "password_reset": "[%(server_name)s] Password reset", + "email_validation": "[%(server_name)s] Validate your email", +} + + +@attr.s +class EmailSubjectConfig: + message_from_person_in_room = attr.ib(type=str) + message_from_person = attr.ib(type=str) + messages_from_person = attr.ib(type=str) + messages_in_room = attr.ib(type=str) + messages_in_room_and_others = attr.ib(type=str) + messages_from_person_and_others = attr.ib(type=str) + invite_from_person = attr.ib(type=str) + invite_from_person_to_room = attr.ib(type=str) + password_reset = attr.ib(type=str) + email_validation = attr.ib(type=str) + class EmailConfig(Config): section = "email" @@ -71,21 +97,18 @@ class EmailConfig(Config): if parsed[1] == "": raise RuntimeError("Invalid notif_from address") + # A user-configurable template directory template_dir = email_config.get("template_dir") - # we need an absolute path, because we change directory after starting (and - # we don't yet know what auxilliary templates like mail.css we will need). - # (Note that loading as package_resources with jinja.PackageLoader doesn't - # work for the same reason.) - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates") - - self.email_template_dir = os.path.abspath(template_dir) + if isinstance(template_dir, str): + # We need an absolute path, because we change directory after starting (and + # we don't yet know what auxiliary templates like mail.css we will need). + template_dir = os.path.abspath(template_dir) + elif template_dir is not None: + # If template_dir is something other than a str or None, warn the user + raise ConfigError("Config option email.template_dir must be type str") self.email_enable_notifs = email_config.get("enable_notifs", False) - account_validity_config = config.get("account_validity") or {} - account_validity_renewal_enabled = account_validity_config.get("renew_at") - self.threepid_behaviour_email = ( # Have Synapse handle the email sending if account_threepid_delegates.email # is not defined @@ -139,19 +162,6 @@ class EmailConfig(Config): email_config.get("validation_token_lifetime", "1h") ) - if ( - self.email_enable_notifs - or account_validity_renewal_enabled - or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL - ): - # make sure we can import the required deps - import jinja2 - import bleach - - # prevent unused warnings - jinja2 - bleach - if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL: missing = [] if not self.email_notif_from: @@ -169,49 +179,49 @@ class EmailConfig(Config): # These email templates have placeholders in them, and thus must be # parsed using a templating engine during a request - self.email_password_reset_template_html = email_config.get( + password_reset_template_html = email_config.get( "password_reset_template_html", "password_reset.html" ) - self.email_password_reset_template_text = email_config.get( + password_reset_template_text = email_config.get( "password_reset_template_text", "password_reset.txt" ) - self.email_registration_template_html = email_config.get( + registration_template_html = email_config.get( "registration_template_html", "registration.html" ) - self.email_registration_template_text = email_config.get( + registration_template_text = email_config.get( "registration_template_text", "registration.txt" ) - self.email_add_threepid_template_html = email_config.get( + add_threepid_template_html = email_config.get( "add_threepid_template_html", "add_threepid.html" ) - self.email_add_threepid_template_text = email_config.get( + add_threepid_template_text = email_config.get( "add_threepid_template_text", "add_threepid.txt" ) - self.email_password_reset_template_failure_html = email_config.get( + password_reset_template_failure_html = email_config.get( "password_reset_template_failure_html", "password_reset_failure.html" ) - self.email_registration_template_failure_html = email_config.get( + registration_template_failure_html = email_config.get( "registration_template_failure_html", "registration_failure.html" ) - self.email_add_threepid_template_failure_html = email_config.get( + add_threepid_template_failure_html = email_config.get( "add_threepid_template_failure_html", "add_threepid_failure.html" ) # These templates do not support any placeholder variables, so we # will read them from disk once during setup - email_password_reset_template_success_html = email_config.get( + password_reset_template_success_html = email_config.get( "password_reset_template_success_html", "password_reset_success.html" ) - email_registration_template_success_html = email_config.get( + registration_template_success_html = email_config.get( "registration_template_success_html", "registration_success.html" ) - email_add_threepid_template_success_html = email_config.get( + add_threepid_template_success_html = email_config.get( "add_threepid_template_success_html", "add_threepid_success.html" ) - # Check templates exist - for f in [ + # Read all templates from disk + ( self.email_password_reset_template_html, self.email_password_reset_template_text, self.email_registration_template_html, @@ -221,32 +231,36 @@ class EmailConfig(Config): self.email_password_reset_template_failure_html, self.email_registration_template_failure_html, self.email_add_threepid_template_failure_html, - email_password_reset_template_success_html, - email_registration_template_success_html, - email_add_threepid_template_success_html, - ]: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find template file %s" % (p,)) - - # Retrieve content of web templates - filepath = os.path.join( - self.email_template_dir, email_password_reset_template_success_html - ) - self.email_password_reset_template_success_html = self.read_file( - filepath, "email.password_reset_template_success_html" + password_reset_template_success_html_template, + registration_template_success_html_template, + add_threepid_template_success_html_template, + ) = self.read_templates( + [ + password_reset_template_html, + password_reset_template_text, + registration_template_html, + registration_template_text, + add_threepid_template_html, + add_threepid_template_text, + password_reset_template_failure_html, + registration_template_failure_html, + add_threepid_template_failure_html, + password_reset_template_success_html, + registration_template_success_html, + add_threepid_template_success_html, + ], + template_dir, ) - filepath = os.path.join( - self.email_template_dir, email_registration_template_success_html - ) - self.email_registration_template_success_html_content = self.read_file( - filepath, "email.registration_template_success_html" + + # Render templates that do not contain any placeholders + self.email_password_reset_template_success_html_content = ( + password_reset_template_success_html_template.render() ) - filepath = os.path.join( - self.email_template_dir, email_add_threepid_template_success_html + self.email_registration_template_success_html_content = ( + registration_template_success_html_template.render() ) - self.email_add_threepid_template_success_html_content = self.read_file( - filepath, "email.add_threepid_template_success_html" + self.email_add_threepid_template_success_html_content = ( + add_threepid_template_success_html_template.render() ) if self.email_enable_notifs: @@ -263,17 +277,19 @@ class EmailConfig(Config): % (", ".join(missing),) ) - self.email_notif_template_html = email_config.get( + notif_template_html = email_config.get( "notif_template_html", "notif_mail.html" ) - self.email_notif_template_text = email_config.get( + notif_template_text = email_config.get( "notif_template_text", "notif_mail.txt" ) - for f in self.email_notif_template_text, self.email_notif_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.email_notif_template_html, + self.email_notif_template_text, + ) = self.read_templates( + [notif_template_html, notif_template_text], template_dir, + ) self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True @@ -282,21 +298,32 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if account_validity_renewal_enabled: - self.email_expiry_template_html = email_config.get( + if self.account_validity.renew_by_email_enabled: + expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) - self.email_expiry_template_text = email_config.get( + expiry_template_text = email_config.get( "expiry_template_text", "notice_expiry.txt" ) - for f in self.email_expiry_template_text, self.email_expiry_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.account_validity_template_html, + self.account_validity_template_text, + ) = self.read_templates( + [expiry_template_html, expiry_template_text], template_dir, + ) + + subjects_config = email_config.get("subjects", {}) + subjects = {} + + for key, default in DEFAULT_SUBJECTS.items(): + subjects[key] = subjects_config.get(key, default) + + self.email_subjects = EmailSubjectConfig(**subjects) def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """\ + return ( + """\ # Configuration for sending emails from Synapse. # email: @@ -324,17 +351,17 @@ class EmailConfig(Config): # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # - # The placeholder '%(app)s' will be replaced by the application name, + # The placeholder '%%(app)s' will be replaced by the application name, # which is normally 'app_name' (below), but may be overridden by the # Matrix client application. # - # Note that the placeholder must be written '%(app)s', including the + # Note that the placeholder must be written '%%(app)s', including the # trailing 's'. # - #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" + #notif_from: "Your Friendly %%(app)s homeserver <noreply@example.com>" - # app_name defines the default value for '%(app)s' in notif_from. It - # defaults to 'Matrix'. + # app_name defines the default value for '%%(app)s' in notif_from and email + # subjects. It defaults to 'Matrix'. # #app_name: my_branded_matrix_server @@ -364,9 +391,7 @@ class EmailConfig(Config): # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # @@ -402,7 +427,76 @@ class EmailConfig(Config): # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates # #template_dir: "res/templates" + + # Subjects to use when sending emails from Synapse. + # + # The placeholder '%%(app)s' will be replaced with the value of the 'app_name' + # setting above, or by a value dictated by the Matrix client application. + # + # If a subject isn't overridden in this configuration file, the value used as + # its example will be used. + # + #subjects: + + # Subjects for notification emails. + # + # On top of the '%%(app)s' placeholder, these can use the following + # placeholders: + # + # * '%%(person)s', which will be replaced by the display name of the user(s) + # that sent the message(s), e.g. "Alice and Bob". + # * '%%(room)s', which will be replaced by the name of the room the + # message(s) have been sent to, e.g. "My super room". + # + # See the example provided for each setting to see which placeholder can be + # used and how to use them. + # + # Subject to use to notify about one message from one or more user(s) in a + # room which has a name. + #message_from_person_in_room: "%(message_from_person_in_room)s" + # + # Subject to use to notify about one message from one or more user(s) in a + # room which doesn't have a name. + #message_from_person: "%(message_from_person)s" + # + # Subject to use to notify about multiple messages from one or more users in + # a room which doesn't have a name. + #messages_from_person: "%(messages_from_person)s" + # + # Subject to use to notify about multiple messages in a room which has a + # name. + #messages_in_room: "%(messages_in_room)s" + # + # Subject to use to notify about multiple messages in multiple rooms. + #messages_in_room_and_others: "%(messages_in_room_and_others)s" + # + # Subject to use to notify about multiple messages from multiple persons in + # multiple rooms. This is similar to the setting above except it's used when + # the room in which the notification was triggered has no name. + #messages_from_person_and_others: "%(messages_from_person_and_others)s" + # + # Subject to use to notify about an invite to a room which has a name. + #invite_from_person_to_room: "%(invite_from_person_to_room)s" + # + # Subject to use to notify about an invite to a room which doesn't have a + # name. + #invite_from_person: "%(invite_from_person)s" + + # Subject for emails related to account administration. + # + # On top of the '%%(app)s' placeholder, these one can use the + # '%%(server_name)s' placeholder, which will be replaced by the value of the + # 'server_name' setting in your Synapse configuration. + # + # Subject to use when sending a password reset email. + #password_reset: "%(password_reset)s" + # + # Subject to use when sending a verification email to assert an address's + # ownership. + #email_validation: "%(email_validation)s" """ + % DEFAULT_SUBJECTS + ) class ThreepidBehaviour(Enum): diff --git a/synapse/config/federation.py b/synapse/config/federation.py new file mode 100644 index 0000000000..2c77d8f85b --- /dev/null +++ b/synapse/config/federation.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from netaddr import IPSet + +from ._base import Config, ConfigError + + +class FederationConfig(Config): + section = "federation" + + def read_config(self, config, **kwargs): + # FIXME: federation_domain_whitelist needs sytests + self.federation_domain_whitelist = None # type: Optional[dict] + federation_domain_whitelist = config.get("federation_domain_whitelist", None) + + if federation_domain_whitelist is not None: + # turn the whitelist into a hash for speed of lookup + self.federation_domain_whitelist = {} + + for domain in federation_domain_whitelist: + self.federation_domain_whitelist[domain] = True + + self.federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", [] + ) + + # Attempt to create an IPSet from the given ranges + try: + self.federation_ip_range_blacklist = IPSet( + self.federation_ip_range_blacklist + ) + + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + except Exception as e: + raise ConfigError( + "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e + ) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # Restrict federation to the following whitelist of domains. + # N.B. we recommend also firewalling your federation listener to limit + # inbound federation traffic as early as possible, rather than relying + # purely on this application-layer restriction. If not specified, the + # default is to whitelist everything. + # + #federation_domain_whitelist: + # - lon.example.com + # - nyc.example.com + # - syd.example.com + + # Prevent federation requests from being sent to the following + # blacklist IP address CIDR ranges. If this option is not specified, or + # specified with an empty list, no ip range blacklist will be enforced. + # + # As of Synapse v1.4.0 this option also affects any outbound requests to identity + # servers provided by user input. + # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) + # + federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 2c7b3a699f..556e291495 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -23,6 +23,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig from .key import KeyConfig @@ -36,6 +37,7 @@ from .ratelimiting import RatelimitConfig from .redis import RedisConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig +from .room import RoomConfig from .room_directory import RoomDirectoryConfig from .saml2_config import SAML2Config from .server import ServerConfig @@ -56,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, TlsConfig, + FederationConfig, CacheConfig, DatabaseConfig, LoggingConfig, @@ -75,10 +78,10 @@ class HomeServerConfig(RootConfig): JWTConfig, PasswordConfig, EmailConfig, - WorkerConfig, PasswordAuthProviderConfig, PushConfig, SpamCheckerConfig, + RoomConfig, GroupsConfig, UserDirectoryConfig, ConsentConfig, @@ -87,5 +90,7 @@ class HomeServerConfig(RootConfig): RoomDirectoryConfig, ThirdPartyRulesConfig, TracerConfig, + WorkerConfig, RedisConfig, + FederationConfig, ] diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index a568726985..3252ad9e7f 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -32,6 +32,11 @@ class JWTConfig(Config): self.jwt_secret = jwt_config["secret"] self.jwt_algorithm = jwt_config["algorithm"] + # The issuer and audiences are optional, if provided, it is asserted + # that the claims exist on the JWT. + self.jwt_issuer = jwt_config.get("issuer") + self.jwt_audiences = jwt_config.get("audiences") + try: import jwt @@ -42,13 +47,63 @@ class JWTConfig(Config): self.jwt_enabled = False self.jwt_secret = None self.jwt_algorithm = None + self.jwt_issuer = None + self.jwt_audiences = None def generate_config_section(self, **kwargs): return """\ - # The JWT needs to contain a globally unique "sub" (subject) claim. + # JSON web token integration. The following settings can be used to make + # Synapse JSON web tokens for authentication, instead of its internal + # password database. + # + # Each JSON Web Token needs to contain a "sub" (subject) claim, which is + # used as the localpart of the mxid. + # + # Additionally, the expiration time ("exp"), not before time ("nbf"), + # and issued at ("iat") claims are validated if present. + # + # Note that this is a non-standard login type and client support is + # expected to be non-existant. + # + # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md. # #jwt_config: - # enabled: true - # secret: "a secret" - # algorithm: "HS256" + # Uncomment the following to enable authorization using JSON web + # tokens. Defaults to false. + # + #enabled: true + + # This is either the private shared secret or the public key used to + # decode the contents of the JSON web token. + # + # Required if 'enabled' is true. + # + #secret: "provided-by-your-issuer" + + # The algorithm used to sign the JSON web token. + # + # Supported algorithms are listed at + # https://pyjwt.readthedocs.io/en/latest/algorithms.html + # + # Required if 'enabled' is true. + # + #algorithm: "provided-by-your-issuer" + + # The issuer to validate the "iss" claim against. + # + # Optional, if provided the "iss" claim will be required and + # validated for all JSON web tokens. + # + #issuer: "provided-by-your-issuer" + + # A list of audiences to validate the "aud" claim against. + # + # Optional, if provided the "aud" claim will be required and + # validated for all JSON web tokens. + # + # Note that if the "aud" claim is included in a JSON web token then + # validation will fail without configuring audiences. + # + #audiences: + # - "provided-by-your-issuer" """ diff --git a/synapse/config/key.py b/synapse/config/key.py index b529ea5da0..de964dff13 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -82,7 +82,7 @@ logger = logging.getLogger(__name__) @attr.s -class TrustedKeyServer(object): +class TrustedKeyServer: # string: name of the server. server_name = attr.ib() diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 49f6c32beb..c96e6ef62a 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -55,24 +55,33 @@ formatters: format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \ %(request)s - %(message)s' -filters: - context: - (): synapse.logging.context.LoggingContextFilter - request: "" - handlers: file: - class: logging.handlers.RotatingFileHandler + class: logging.handlers.TimedRotatingFileHandler formatter: precise filename: ${log_file} - maxBytes: 104857600 - backupCount: 10 - filters: [context] + when: midnight + backupCount: 3 # Does not include the current log file. encoding: utf8 + + # Default to buffering writes to log file for efficiency. This means that + # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR + # logs will still be flushed immediately. + buffer: + class: logging.handlers.MemoryHandler + target: file + # The capacity is the number of log lines that are buffered before + # being written to disk. Increasing this will lead to better + # performance, at the expensive of it taking longer for log lines to + # be written to disk. + capacity: 10 + flushLevel: 30 # Flush for WARNING logs as well + + # A handler that writes logs to stderr. Unused by default, but can be used + # instead of "buffer" and "file" in the logger handlers. console: class: logging.StreamHandler formatter: precise - filters: [context] loggers: synapse.storage.SQL: @@ -80,9 +89,24 @@ loggers: # information such as access tokens. level: INFO + twisted: + # We send the twisted logging directly to the file handler, + # to work around https://github.com/matrix-org/synapse/issues/3471 + # when using "buffer" logger. Use "console" to log to stderr instead. + handlers: [file] + propagate: false + root: level: INFO - handlers: [file, console] + + # Write logs to the `buffer` handler, which will buffer them together in memory, + # then write them to a file. + # + # Replace "buffer" with "console" to log to stderr instead. (Note that you'll + # also need to update the configuation for the `twisted` logger above, in + # this case.) + # + handlers: [buffer] disable_existing_loggers: false """ @@ -168,11 +192,26 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): handler = logging.StreamHandler() handler.setFormatter(formatter) - handler.addFilter(LoggingContextFilter(request="")) logger.addHandler(handler) else: logging.config.dictConfig(log_config) + # We add a log record factory that runs all messages through the + # LoggingContextFilter so that we get the context *at the time we log* + # rather than when we write to a handler. This can be done in config using + # filter options, but care must when using e.g. MemoryHandler to buffer + # writes. + + log_filter = LoggingContextFilter(request="") + old_factory = logging.getLogRecordFactory() + + def factory(*args, **kwargs): + record = old_factory(*args, **kwargs) + log_filter.filter(record) + return record + + logging.setLogRecordFactory(factory) + # Route Twisted's native logging through to the standard library logging # system. observer = STDLibLogObserver() @@ -214,7 +253,7 @@ def setup_logging( Set up the logging subsystem. Args: - config (LoggingConfig | synapse.config.workers.WorkerConfig): + config (LoggingConfig | synapse.config.worker.WorkerConfig): configuration data use_worker_options (bool): True to use the 'worker_log_config' option diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 6aad0d37c0..dfd27e1523 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -22,7 +22,7 @@ from ._base import Config, ConfigError @attr.s -class MetricsFlags(object): +class MetricsFlags: known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool)) @classmethod diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index e24dd637bc..e0939bce84 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -89,7 +89,7 @@ class OIDCConfig(Config): # use an OpenID Connect Provider for authentication, instead of its internal # password database. # - # See https://github.com/matrix-org/synapse/blob/master/openid.md. + # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md. # oidc_config: # Uncomment the following to enable authorization against an OpenID Connect diff --git a/synapse/config/push.py b/synapse/config/push.py index 6f2b3a7faa..a1f3752c8a 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ShardedWorkerHandlingConfig class PushConfig(Config): @@ -24,6 +24,9 @@ class PushConfig(Config): push_config = config.get("push", {}) self.push_include_content = push_config.get("include_content", True) + pusher_instances = config.get("pusher_instances") or [] + self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) + # There was a a 'redact_content' setting but mistakenly read from the # 'email'section'. Check for the flag in the 'push' section, and log, # but do not honour it to avoid nasty surprises when people upgrade. diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 2dd94bae2b..14b8836197 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -17,7 +17,7 @@ from typing import Dict from ._base import Config -class RateLimitConfig(object): +class RateLimitConfig: def __init__( self, config: Dict[str, float], @@ -27,7 +27,7 @@ class RateLimitConfig(object): self.burst_count = config.get("burst_count", defaults["burst_count"]) -class FederationRateLimitConfig(object): +class FederationRateLimitConfig: _items_and_default = { "window_size": 1000, "sleep_limit": 10, @@ -93,6 +93,15 @@ class RatelimitConfig(Config): if rc_admin_redaction: self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction) + self.rc_joins_local = RateLimitConfig( + config.get("rc_joins", {}).get("local", {}), + defaults={"per_second": 0.1, "burst_count": 3}, + ) + self.rc_joins_remote = RateLimitConfig( + config.get("rc_joins", {}).get("remote", {}), + defaults={"per_second": 0.01, "burst_count": 3}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -118,6 +127,10 @@ class RatelimitConfig(Config): # - one for ratelimiting redactions by room admins. If this is not explicitly # set then it uses the same ratelimiting as per rc_message. This is useful # to allow room admins to deal with abuse quickly. + # - two for ratelimiting number of rooms a user can join, "local" for when + # users are joining rooms the server is already in (this is cheap) vs + # "remote" for when users are trying to join rooms not on the server (which + # can be more expensive) # # The defaults are as shown below. # @@ -143,6 +156,14 @@ class RatelimitConfig(Config): #rc_admin_redaction: # per_second: 1 # burst_count: 50 + # + #rc_joins: + # local: + # per_second: 0.1 + # burst_count: 3 + # remote: + # per_second: 0.01 + # burst_count: 3 # Ratelimiting settings for incoming federation diff --git a/synapse/config/redis.py b/synapse/config/redis.py index d5d3ca1c9e..1373302335 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -21,7 +21,7 @@ class RedisConfig(Config): section = "redis" def read_config(self, config, **kwargs): - redis_config = config.get("redis", {}) + redis_config = config.get("redis") or {} self.redis_enabled = redis_config.get("enabled", False) if not self.redis_enabled: @@ -32,3 +32,24 @@ class RedisConfig(Config): self.redis_host = redis_config.get("host", "localhost") self.redis_port = redis_config.get("port", 6379) self.redis_password = redis_config.get("password") + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # Configuration for Redis when using workers. This *must* be enabled when + # using workers (unless using old style direct TCP configuration). + # + redis: + # Uncomment the below to enable Redis support. + # + #enabled: true + + # Optional host and port to use to connect to redis. Defaults to + # localhost and 6379 + # + #host: localhost + #port: 6379 + + # Optional password if configured on the Redis instance + # + #password: <secret_password> + """ diff --git a/synapse/config/registration.py b/synapse/config/registration.py index c7487178dc..a670080fe2 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -18,8 +18,9 @@ from distutils.util import strtobool import pkg_resources +from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError -from synapse.types import RoomAlias +from synapse.types import RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols @@ -130,7 +131,50 @@ class RegistrationConfig(Config): for room_alias in self.auto_join_rooms: if not RoomAlias.is_valid(room_alias): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) + + # Options for creating auto-join rooms if they do not exist yet. self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.autocreate_auto_join_rooms_federated = config.get( + "autocreate_auto_join_rooms_federated", True + ) + self.autocreate_auto_join_room_preset = ( + config.get("autocreate_auto_join_room_preset") + or RoomCreationPreset.PUBLIC_CHAT + ) + self.auto_join_room_requires_invite = self.autocreate_auto_join_room_preset in { + RoomCreationPreset.PRIVATE_CHAT, + RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + } + + # Pull the creater/inviter from the configuration, this gets used to + # send invites for invite-only rooms. + mxid_localpart = config.get("auto_join_mxid_localpart") + self.auto_join_user_id = None + if mxid_localpart: + # Convert the localpart to a full mxid. + self.auto_join_user_id = UserID( + mxid_localpart, self.server_name + ).to_string() + + if self.autocreate_auto_join_rooms: + # Ensure the preset is a known value. + if self.autocreate_auto_join_room_preset not in { + RoomCreationPreset.PUBLIC_CHAT, + RoomCreationPreset.PRIVATE_CHAT, + RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + }: + raise ConfigError("Invalid value for autocreate_auto_join_room_preset") + # If the preset requires invitations to be sent, ensure there's a + # configured user to send them from. + if self.auto_join_room_requires_invite: + if not mxid_localpart: + raise ConfigError( + "The configuration option `auto_join_mxid_localpart` is required if " + "`autocreate_auto_join_room_preset` is set to private_chat or trusted_private_chat, such that " + "Synapse knows who to send invitations from. Please " + "configure `auto_join_mxid_localpart`." + ) + self.auto_join_rooms_for_guests = config.get("auto_join_rooms_for_guests", True) self.enable_set_displayname = config.get("enable_set_displayname", True) @@ -297,24 +341,6 @@ class RegistrationConfig(Config): # #default_identity_server: https://matrix.org - # The list of identity servers trusted to verify third party - # identifiers by this server. - # - # Also defines the ID server which will be called when an account is - # deactivated (one will be picked arbitrarily). - # - # Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity - # server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a - # background migration script, informing itself that the identity server all of its - # 3PIDs have been bound to is likely one of the below. - # - # As of Synapse v1.4.0, all other functionality of this option has been deprecated, and - # it is now solely used for the purposes of the background migration script, and can be - # removed once it has run. - #trusted_third_party_id_servers: - # - matrix.org - # - vector.im - # Handle threepid (email/phone etc) registration and password resets through a set of # *trusted* identity servers. Note that this allows the configured identity server to # reset passwords for accounts! @@ -365,7 +391,11 @@ class RegistrationConfig(Config): #enable_3pid_changes: false # Users who register on this homeserver will automatically be joined - # to these rooms + # to these rooms. + # + # By default, any room aliases included in this list will be created + # as a publicly joinable room when the first user registers for the + # homeserver. This behaviour can be customised with the settings below. # #auto_join_rooms: # - "#example:example.com" @@ -373,10 +403,62 @@ class RegistrationConfig(Config): # Where auto_join_rooms are specified, setting this flag ensures that the # the rooms exist by creating them when the first user on the # homeserver registers. + # + # By default the auto-created rooms are publicly joinable from any federated + # server. Use the autocreate_auto_join_rooms_federated and + # autocreate_auto_join_room_preset settings below to customise this behaviour. + # # Setting to false means that if the rooms are not manually created, # users cannot be auto-joined since they do not exist. # - #autocreate_auto_join_rooms: true + # Defaults to true. Uncomment the following line to disable automatically + # creating auto-join rooms. + # + #autocreate_auto_join_rooms: false + + # Whether the auto_join_rooms that are auto-created are available via + # federation. Only has an effect if autocreate_auto_join_rooms is true. + # + # Note that whether a room is federated cannot be modified after + # creation. + # + # Defaults to true: the room will be joinable from other servers. + # Uncomment the following to prevent users from other homeservers from + # joining these rooms. + # + #autocreate_auto_join_rooms_federated: false + + # The room preset to use when auto-creating one of auto_join_rooms. Only has an + # effect if autocreate_auto_join_rooms is true. + # + # This can be one of "public_chat", "private_chat", or "trusted_private_chat". + # If a value of "private_chat" or "trusted_private_chat" is used then + # auto_join_mxid_localpart must also be configured. + # + # Defaults to "public_chat", meaning that the room is joinable by anyone, including + # federated servers if autocreate_auto_join_rooms_federated is true (the default). + # Uncomment the following to require an invitation to join these rooms. + # + #autocreate_auto_join_room_preset: private_chat + + # The local part of the user id which is used to create auto_join_rooms if + # autocreate_auto_join_rooms is true. If this is not provided then the + # initial user account that registers will be used to create the rooms. + # + # The user id is also used to invite new users to any auto-join rooms which + # are set to invite-only. + # + # It *must* be configured if autocreate_auto_join_room_preset is set to + # "private_chat" or "trusted_private_chat". + # + # Note that this must be specified in order for new users to be correctly + # invited to any auto-join rooms which have been set to invite-only (either + # at the time of creation or subsequently). + # + # Note that, if the room already exists, this user must be joined and + # have the appropriate permissions to invite new members. + # + #auto_join_mxid_localpart: system # When auto_join_rooms is specified, setting this flag to false prevents # guest accounts from being automatically joined to the rooms. diff --git a/synapse/config/repository.py b/synapse/config/repository.py index b751d02d37..01009f3924 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -94,6 +94,12 @@ class ContentRepositoryConfig(Config): else: self.can_load_media_repo = True + # Whether this instance should be the one to run the background jobs to + # e.g clean up old URL previews. + self.media_instance_running_background_jobs = config.get( + "media_instance_running_background_jobs", + ) + self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M")) self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M")) self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M")) diff --git a/synapse/config/room.py b/synapse/config/room.py new file mode 100644 index 0000000000..692d7a1936 --- /dev/null +++ b/synapse/config/room.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.constants import RoomCreationPreset + +from ._base import Config, ConfigError + +logger = logging.Logger(__name__) + + +class RoomDefaultEncryptionTypes: + """Possible values for the encryption_enabled_by_default_for_room_type config option""" + + ALL = "all" + INVITE = "invite" + OFF = "off" + + +class RoomConfig(Config): + section = "room" + + def read_config(self, config, **kwargs): + # Whether new, locally-created rooms should have encryption enabled + encryption_for_room_type = config.get( + "encryption_enabled_by_default_for_room_type", + RoomDefaultEncryptionTypes.OFF, + ) + if encryption_for_room_type == RoomDefaultEncryptionTypes.ALL: + self.encryption_enabled_by_default_for_room_presets = [ + RoomCreationPreset.PRIVATE_CHAT, + RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + RoomCreationPreset.PUBLIC_CHAT, + ] + elif encryption_for_room_type == RoomDefaultEncryptionTypes.INVITE: + self.encryption_enabled_by_default_for_room_presets = [ + RoomCreationPreset.PRIVATE_CHAT, + RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + ] + elif ( + encryption_for_room_type == RoomDefaultEncryptionTypes.OFF + or encryption_for_room_type is False + ): + # PyYAML translates "off" into False if it's unquoted, so we also need to + # check for encryption_for_room_type being False. + self.encryption_enabled_by_default_for_room_presets = [] + else: + raise ConfigError( + "Invalid value for encryption_enabled_by_default_for_room_type" + ) + + def generate_config_section(self, **kwargs): + return """\ + ## Rooms ## + + # Controls whether locally-created rooms should be end-to-end encrypted by + # default. + # + # Possible options are "all", "invite", and "off". They are defined as: + # + # * "all": any locally-created room + # * "invite": any room created with the "private_chat" or "trusted_private_chat" + # room creation presets + # * "off": this option will take no effect + # + # The default value is "off". + # + # Note that this option will only affect rooms created after it is set. It + # will also not affect rooms created by other servers. + # + #encryption_enabled_by_default_for_room_type: invite + """ diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 7ac7699676..6de1f9d103 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -149,7 +149,7 @@ class RoomDirectoryConfig(Config): return False -class _RoomDirectoryRule(object): +class _RoomDirectoryRule: """Helper class to test whether a room directory action is allowed, like creating an alias or publishing a room. """ diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index d0a19751e8..cc7401888b 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -15,14 +15,15 @@ # limitations under the License. import logging +from typing import Any, List -import jinja2 -import pkg_resources +import attr from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +from ._util import validate_config logger = logging.getLogger(__name__) @@ -80,6 +81,11 @@ class SAML2Config(Config): self.saml2_enabled = True + attribute_requirements = saml2_config.get("attribute_requirements") or [] + self.attribute_requirements = _parse_attribute_requirements_def( + attribute_requirements + ) + self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( "grandfathered_mxid_source_attribute", "uid" ) @@ -160,18 +166,12 @@ class SAML2Config(Config): # session lifetime: in milliseconds self.saml2_session_lifetime = self.parse_duration( - saml2_config.get("saml_session_lifetime", "5m") + saml2_config.get("saml_session_lifetime", "15m") ) - template_dir = saml2_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - - loader = jinja2.FileSystemLoader(template_dir) - # enable auto-escape here, to having to remember to escape manually in the - # template - env = jinja2.Environment(loader=loader, autoescape=True) - self.saml2_error_html_template = env.get_template("saml_error.html") + self.saml2_error_html_template = self.read_templates( + ["saml_error.html"], saml2_config.get("template_dir") + )[0] def _default_saml_config_dict( self, required_attributes: set, optional_attributes: set @@ -286,7 +286,7 @@ class SAML2Config(Config): # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. - # The default is 5 minutes. + # The default is 15 minutes. # #saml_session_lifetime: 5m @@ -341,6 +341,17 @@ class SAML2Config(Config): # #grandfathered_mxid_source_attribute: upn + # It is possible to configure Synapse to only allow logins if SAML attributes + # match particular values. The requirements can be listed under + # `attribute_requirements` as shown below. All of the listed attributes must + # match for the login to be permitted. + # + #attribute_requirements: + # - attribute: userGroup + # value: "staff" + # - attribute: department + # value: "sales" + # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # @@ -368,3 +379,34 @@ class SAML2Config(Config): """ % { "config_dir_path": config_dir_path } + + +@attr.s(frozen=True) +class SamlAttributeRequirement: + """Object describing a single requirement for SAML attributes.""" + + attribute = attr.ib(type=str) + value = attr.ib(type=str) + + JSON_SCHEMA = { + "type": "object", + "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, + "required": ["attribute", "value"], + } + + +ATTRIBUTE_REQUIREMENTS_SCHEMA = { + "type": "array", + "items": SamlAttributeRequirement.JSON_SCHEMA, +} + + +def _parse_attribute_requirements_def( + attribute_requirements: Any, +) -> List[SamlAttributeRequirement]: + validate_config( + ATTRIBUTE_REQUIREMENTS_SCHEMA, + attribute_requirements, + config_path=["saml2_config", "attribute_requirements"], + ) + return [SamlAttributeRequirement(**x) for x in attribute_requirements] diff --git a/synapse/config/server.py b/synapse/config/server.py index f57eefc99c..e85c6a0840 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,15 +19,13 @@ import logging import os.path import re from textwrap import indent -from typing import Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional import attr import yaml -from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name -from synapse.python_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError @@ -57,6 +55,64 @@ on how to configure the new listener. --------------------------------------------------------------------------------""" +KNOWN_LISTENER_TYPES = { + "http", + "metrics", + "manhole", + "replication", +} + +KNOWN_RESOURCES = { + "client", + "consent", + "federation", + "keys", + "media", + "metrics", + "openid", + "replication", + "static", + "webclient", +} + + +@attr.s(frozen=True) +class HttpResourceConfig: + names = attr.ib( + type=List[str], + factory=list, + validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)), # type: ignore + ) + compress = attr.ib( + type=bool, + default=False, + validator=attr.validators.optional(attr.validators.instance_of(bool)), # type: ignore[arg-type] + ) + + +@attr.s(frozen=True) +class HttpListenerConfig: + """Object describing the http-specific parts of the config of a listener""" + + x_forwarded = attr.ib(type=bool, default=False) + resources = attr.ib(type=List[HttpResourceConfig], factory=list) + additional_resources = attr.ib(type=Dict[str, dict], factory=dict) + tag = attr.ib(type=str, default=None) + + +@attr.s(frozen=True) +class ListenerConfig: + """Object describing the configuration of a single listener.""" + + port = attr.ib(type=int, validator=attr.validators.instance_of(int)) + bind_addresses = attr.ib(type=List[str]) + type = attr.ib(type=str, validator=attr.validators.in_(KNOWN_LISTENER_TYPES)) + tls = attr.ib(type=bool, default=False) + + # http_options is only populated if type=http + http_options = attr.ib(type=Optional[HttpListenerConfig], default=None) + + class ServerConfig(Config): section = "server" @@ -78,11 +134,6 @@ class ServerConfig(Config): self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.public_baseurl = config.get("public_baseurl") - # Whether to send federation traffic out in this process. This only - # applies to some federation traffic, and so shouldn't be used to - # "disable" federation - self.send_federation = config.get("send_federation", True) - # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -155,7 +206,7 @@ class ServerConfig(Config): # errors when attempting to search for messages. self.enable_search = config.get("enable_search", True) - self.filter_timeline_limit = config.get("filter_timeline_limit", -1) + self.filter_timeline_limit = config.get("filter_timeline_limit", 100) # Whether we should block invites sent to users on this server # (other than those sent by local server admins) @@ -205,34 +256,6 @@ class ServerConfig(Config): # due to resource constraints self.admin_contact = config.get("admin_contact", None) - # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None # type: Optional[dict] - federation_domain_whitelist = config.get("federation_domain_whitelist", None) - - if federation_domain_whitelist is not None: - # turn the whitelist into a hash for speed of lookup - self.federation_domain_whitelist = {} - - for domain in federation_domain_whitelist: - self.federation_domain_whitelist[domain] = True - - self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [] - ) - - # Attempt to create an IPSet from the given ranges - try: - self.federation_ip_range_blacklist = IPSet( - self.federation_ip_range_blacklist - ) - - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - except Exception as e: - raise ConfigError( - "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e - ) - if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" @@ -379,38 +402,21 @@ class ServerConfig(Config): } ] - self.listeners = [] # type: List[dict] - for listener in config.get("listeners", []): - if not isinstance(listener.get("port", None), int): - raise ConfigError( - "Listener configuration is lacking a valid 'port' option" - ) + self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])] - if listener.setdefault("tls", False): - # no_tls is not really supported any more, but let's grandfather it in - # here. - if config.get("no_tls", False): + # no_tls is not really supported any more, but let's grandfather it in + # here. + if config.get("no_tls", False): + l2 = [] + for listener in self.listeners: + if listener.tls: logger.info( - "Ignoring TLS-enabled listener on port %i due to no_tls" + "Ignoring TLS-enabled listener on port %i due to no_tls", + listener.port, ) - continue - - bind_address = listener.pop("bind_address", None) - bind_addresses = listener.setdefault("bind_addresses", []) - - # if bind_address was specified, add it to the list of addresses - if bind_address: - bind_addresses.append(bind_address) - - # if we still have an empty list of addresses, use the default list - if not bind_addresses: - if listener["type"] == "metrics": - # the metrics listener doesn't support IPv6 - bind_addresses.append("0.0.0.0") else: - bind_addresses.extend(DEFAULT_BIND_ADDRESSES) - - self.listeners.append(listener) + l2.append(listener) + self.listeners = l2 if not self.web_client_location: _warn_if_webclient_configured(self.listeners) @@ -418,7 +424,7 @@ class ServerConfig(Config): self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) @attr.s - class LimitRemoteRoomsConfig(object): + class LimitRemoteRoomsConfig: enabled = attr.ib( validator=attr.validators.instance_of(bool), default=False ) @@ -432,6 +438,9 @@ class ServerConfig(Config): validator=attr.validators.instance_of(str), default=ROOM_COMPLEXITY_TOO_GREAT, ) + admins_can_join = attr.ib( + validator=attr.validators.instance_of(bool), default=False + ) self.limit_remote_rooms = LimitRemoteRoomsConfig( **(config.get("limit_remote_rooms") or {}) @@ -446,43 +455,41 @@ class ServerConfig(Config): bind_host = config.get("bind_host", "") gzip_responses = config.get("gzip_responses", True) + http_options = HttpListenerConfig( + resources=[ + HttpResourceConfig(names=["client"], compress=gzip_responses), + HttpResourceConfig(names=["federation"]), + ], + ) + self.listeners.append( - { - "port": bind_port, - "bind_addresses": [bind_host], - "tls": True, - "type": "http", - "resources": [ - {"names": ["client"], "compress": gzip_responses}, - {"names": ["federation"], "compress": False}, - ], - } + ListenerConfig( + port=bind_port, + bind_addresses=[bind_host], + tls=True, + type="http", + http_options=http_options, + ) ) unsecure_port = config.get("unsecure_port", bind_port - 400) if unsecure_port: self.listeners.append( - { - "port": unsecure_port, - "bind_addresses": [bind_host], - "tls": False, - "type": "http", - "resources": [ - {"names": ["client"], "compress": gzip_responses}, - {"names": ["federation"], "compress": False}, - ], - } + ListenerConfig( + port=unsecure_port, + bind_addresses=[bind_host], + tls=False, + type="http", + http_options=http_options, + ) ) manhole = config.get("manhole") if manhole: self.listeners.append( - { - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - } + ListenerConfig( + port=manhole, bind_addresses=["127.0.0.1"], type="manhole", + ) ) metrics_port = config.get("metrics_port") @@ -490,17 +497,16 @@ class ServerConfig(Config): logger.warning(METRICS_PORT_WARNING) self.listeners.append( - { - "port": metrics_port, - "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], - "tls": False, - "type": "http", - "resources": [{"names": ["metrics"], "compress": False}], - } + ListenerConfig( + port=metrics_port, + bind_addresses=[config.get("metrics_bind_host", "127.0.0.1")], + type="http", + http_options=HttpListenerConfig( + resources=[HttpResourceConfig(names=["metrics"])] + ), + ) ) - _check_resource_config(self.listeners) - self.cleanup_extremities_with_dummy_events = config.get( "cleanup_extremities_with_dummy_events", True ) @@ -521,8 +527,23 @@ class ServerConfig(Config): "request_token_inhibit_3pid_errors", False, ) + # List of users trialing the new experimental default push rules. This setting is + # not included in the sample configuration file on purpose as it's a temporary + # hack, so that some users can trial the new defaults without impacting every + # user on the homeserver. + users_new_default_push_rules = ( + config.get("users_new_default_push_rules") or [] + ) # type: list + if not isinstance(users_new_default_push_rules, list): + raise ConfigError("'users_new_default_push_rules' must be a list") + + # Turn the list into a set to improve lookup speed. + self.users_new_default_push_rules = set( + users_new_default_push_rules + ) # type: set + def has_tls_listener(self) -> bool: - return any(listener["tls"] for listener in self.listeners) + return any(listener.tls for listener in self.listeners) def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs @@ -687,7 +708,9 @@ class ServerConfig(Config): #gc_thresholds: [700, 10, 10] # Set the limit on the returned events in the timeline in the get - # and sync operations. The default value is -1, means no upper limit. + # and sync operations. The default value is 100. -1 means no upper limit. + # + # Uncomment the following to increase the limit to 5000. # #filter_timeline_limit: 5000 @@ -703,38 +726,6 @@ class ServerConfig(Config): # #enable_search: false - # Restrict federation to the following whitelist of domains. - # N.B. we recommend also firewalling your federation listener to limit - # inbound federation traffic as early as possible, rather than relying - # purely on this application-layer restriction. If not specified, the - # default is to whitelist everything. - # - #federation_domain_whitelist: - # - lon.example.com - # - nyc.example.com - # - syd.example.com - - # Prevent federation requests from being sent to the following - # blacklist IP address CIDR ranges. If this option is not specified, or - # specified with an empty list, no ip range blacklist will be enforced. - # - # As of Synapse v1.4.0 this option also affects any outbound requests to identity - # servers provided by user input. - # - # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly - # listed here, since they correspond to unroutable addresses.) - # - federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -763,7 +754,7 @@ class ServerConfig(Config): # names: a list of names of HTTP resources. See below for a list of # valid resource names. # - # compress: set to true to enable HTTP comression for this resource. + # compress: set to true to enable HTTP compression for this resource. # # additional_resources: Only valid for an 'http' listener. A map of # additional endpoints which should be loaded via dynamic modules. @@ -856,7 +847,7 @@ class ServerConfig(Config): # number of monthly active users. # # 'limit_usage_by_mau' disables/enables monthly active user blocking. When - # anabled and a limit is reached the server returns a 'ResourceLimitError' + # enabled and a limit is reached the server returns a 'ResourceLimitError' # with error type Codes.RESOURCE_LIMIT_EXCEEDED # # 'max_mau_value' is the hard limit of monthly active users above which @@ -917,6 +908,10 @@ class ServerConfig(Config): # #complexity_error: "This room is too complex." + # allow server admins to join complex rooms. Default is false. + # + #admins_can_join: true + # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. # @@ -966,11 +961,10 @@ class ServerConfig(Config): # min_lifetime: 1d # max_lifetime: 1y - # Retention policy limits. If set, a user won't be able to send a - # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' - # that's not within this range. This is especially useful in closed federations, - # in which server admins can make sure every federating server applies the same - # rules. + # Retention policy limits. If set, and the state of a room contains a + # 'm.room.retention' event in its state which contains a 'min_lifetime' or a + # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy + # to these limits when running purge jobs. # #allowed_lifetime_min: 1d #allowed_lifetime_max: 1y @@ -996,12 +990,19 @@ class ServerConfig(Config): # (e.g. every 12h), but not want that purge to be performed by a job that's # iterating over every room it knows, which could be heavy on the server. # + # If any purge job is configured, it is strongly recommended to have at least + # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' + # set, or one job without 'shortest_max_lifetime' and one job without + # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if + # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a + # room's policy to these values is done after the policies are retrieved from + # Synapse's database (which is done using the range specified in a purge job's + # configuration). + # #purge_jobs: - # - shortest_max_lifetime: 1d - # longest_max_lifetime: 3d + # - longest_max_lifetime: 3d # interval: 12h # - shortest_max_lifetime: 3d - # longest_max_lifetime: 1y # interval: 1d # Inhibits the /requestToken endpoints from returning an error that might leak @@ -1081,6 +1082,44 @@ def read_gc_thresholds(thresholds): ) +def parse_listener_def(listener: Any) -> ListenerConfig: + """parse a listener config from the config file""" + listener_type = listener["type"] + + port = listener.get("port") + if not isinstance(port, int): + raise ConfigError("Listener configuration is lacking a valid 'port' option") + + tls = listener.get("tls", False) + + bind_addresses = listener.get("bind_addresses", []) + bind_address = listener.get("bind_address") + # if bind_address was specified, add it to the list of addresses + if bind_address: + bind_addresses.append(bind_address) + + # if we still have an empty list of addresses, use the default list + if not bind_addresses: + if listener_type == "metrics": + # the metrics listener doesn't support IPv6 + bind_addresses.append("0.0.0.0") + else: + bind_addresses.extend(DEFAULT_BIND_ADDRESSES) + + http_config = None + if listener_type == "http": + http_config = HttpListenerConfig( + x_forwarded=listener.get("x_forwarded", False), + resources=[ + HttpResourceConfig(**res) for res in listener.get("resources", []) + ], + additional_resources=listener.get("additional_resources", {}), + tag=listener.get("tag"), + ) + + return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) + + NO_MORE_WEB_CLIENT_WARNING = """ Synapse no longer includes a web client. To enable a web client, configure web_client_location. To remove this warning, remove 'webclient' from the 'listeners' @@ -1088,42 +1127,12 @@ configuration. """ -def _warn_if_webclient_configured(listeners): +def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None: for listener in listeners: - for res in listener.get("resources", []): - for name in res.get("names", []): + if not listener.http_options: + continue + for res in listener.http_options.resources: + for name in res.names: if name == "webclient": logger.warning(NO_MORE_WEB_CLIENT_WARNING) return - - -KNOWN_RESOURCES = ( - "client", - "consent", - "federation", - "keys", - "media", - "metrics", - "openid", - "replication", - "static", - "webclient", -) - - -def _check_resource_config(listeners): - resource_names = { - res_name - for listener in listeners - for res in listener.get("resources", []) - for res_name in res.get("names", []) - } - - for resource in resource_names: - if resource not in KNOWN_RESOURCES: - raise ConfigError("Unknown listener resource '%s'" % (resource,)) - if resource == "consent": - try: - check_requirements("resources.consent") - except DependencyException as e: - raise ConfigError(e.message) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 73b7296399..4427676167 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -12,11 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Any, Dict -import pkg_resources - from ._base import Config @@ -29,22 +26,32 @@ class SSOConfig(Config): def read_config(self, config, **kwargs): sso_config = config.get("sso") or {} # type: Dict[str, Any] - # Pick a template directory in order of: - # * The sso-specific template_dir - # * /path/to/synapse/install/res/templates + # The sso-specific template_dir template_dir = sso_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - self.sso_template_dir = template_dir - self.sso_account_deactivated_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_account_deactivated.html"), - "sso_account_deactivated_template", + # Read templates from disk + ( + self.sso_redirect_confirm_template, + self.sso_auth_confirm_template, + self.sso_error_template, + sso_account_deactivated_template, + sso_auth_success_template, + ) = self.read_templates( + [ + "sso_redirect_confirm.html", + "sso_auth_confirm.html", + "sso_error.html", + "sso_account_deactivated.html", + "sso_auth_success.html", + ], + template_dir, ) - self.sso_auth_success_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_auth_success.html"), - "sso_auth_success_template", + + # These templates have no placeholders, so render them here + self.sso_account_deactivated_template = ( + sso_account_deactivated_template.render() ) + self.sso_auth_success_template = sso_auth_success_template.render() self.sso_client_whitelist = sso_config.get("client_whitelist") or [] diff --git a/synapse/config/tls.py b/synapse/config/tls.py index a65538562b..e368ea564d 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -20,8 +20,6 @@ from datetime import datetime from hashlib import sha256 from typing import List -import six - from unpaddedbase64 import encode_base64 from OpenSSL import SSL, crypto @@ -59,7 +57,7 @@ class TlsConfig(Config): logger.warning(ACME_SUPPORT_ENABLED_WARN) # hyperlink complains on py2 if this is not a Unicode - self.acme_url = six.text_type( + self.acme_url = str( acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory") ) self.acme_port = acme_config.get("port", 80) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index ed06b91a54..c784a71508 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -15,7 +15,8 @@ import attr -from ._base import Config, ConfigError +from ._base import Config, ConfigError, ShardedWorkerHandlingConfig +from .server import ListenerConfig, parse_listener_def @attr.s @@ -33,9 +34,11 @@ class WriterLocations: Attributes: events: The instance that writes to the event and backfill streams. + events: The instance that writes to the typing stream. """ events = attr.ib(default="master", type=str) + typing = attr.ib(default="master", type=str) class WorkerConfig(Config): @@ -52,7 +55,9 @@ class WorkerConfig(Config): if self.worker_app == "synapse.app.homeserver": self.worker_app = None - self.worker_listeners = config.get("worker_listeners", []) + self.worker_listeners = [ + parse_listener_def(x) for x in config.get("worker_listeners", []) + ] self.worker_daemonize = config.get("worker_daemonize") self.worker_pid_file = config.get("worker_pid_file") self.worker_log_config = config.get("worker_log_config") @@ -75,23 +80,20 @@ class WorkerConfig(Config): manhole = config.get("worker_manhole") if manhole: self.worker_listeners.append( - { - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - } + ListenerConfig( + port=manhole, bind_addresses=["127.0.0.1"], type="manhole", + ) ) - if self.worker_listeners: - for listener in self.worker_listeners: - bind_address = listener.pop("bind_address", None) - bind_addresses = listener.setdefault("bind_addresses", []) + # Whether to send federation traffic out in this process. This only + # applies to some federation traffic, and so shouldn't be used to + # "disable" federation + self.send_federation = config.get("send_federation", True) - if bind_address: - bind_addresses.append(bind_address) - elif not bind_addresses: - bind_addresses.append("") + federation_sender_instances = config.get("federation_sender_instances") or [] + self.federation_shard_config = ShardedWorkerHandlingConfig( + federation_sender_instances + ) # A map from instance name to host/port of their HTTP replication endpoint. instance_map = config.get("instance_map") or {} @@ -103,16 +105,52 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events also appears in + # Check that the configured writer for events and typing also appears in # `instance_map`. - if ( - self.writers.events != "master" - and self.writers.events not in self.instance_map - ): - raise ConfigError( - "Instance %r is configured to write events but does not appear in `instance_map` config." - % (self.writers.events,) - ) + for stream in ("events", "typing"): + instance = getattr(self.writers, stream) + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + ## Workers ## + + # Disables sending of outbound federation transactions on the main process. + # Uncomment if using a federation sender worker. + # + #send_federation: false + + # It is possible to run multiple federation sender workers, in which case the + # work is balanced across them. + # + # This configuration must be shared between all federation sender workers, and if + # changed all federation sender workers must be stopped at the same time and then + # started, to ensure that all instances are running with the same config (otherwise + # events may be dropped). + # + #federation_sender_instances: + # - federation_sender1 + + # When using workers this should be a map from `worker_name` to the + # HTTP replication listener of the worker, if configured. + # + #instance_map: + # worker1: + # host: localhost + # port: 8034 + + # Experimental: When using workers you can define which workers should + # handle event persistence and typing notifications. Any worker + # specified here must also be in the `instance_map`. + # + #stream_writers: + # events: worker1 + # typing: worker1 + """ def read_arguments(self, args): # We support a bunch of command line arguments that override options in diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index a5a2a7815d..2b03f5ac76 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -48,6 +48,14 @@ class ServerContextFactory(ContextFactory): connections.""" def __init__(self, config): + # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version, + # switch to those (see https://github.com/pyca/cryptography/issues/5379). + # + # note that, despite the confusing name, SSLv23_METHOD does *not* enforce SSLv2 + # or v3, but is a synonym for TLS_METHOD, which allows the client and server + # to negotiate an appropriate version of TLS constrained by the version options + # set with context.set_options. + # self._context = SSL.Context(SSL.SSLv23_METHOD) self.configure_context(self._context, config) @@ -75,7 +83,7 @@ class ServerContextFactory(ContextFactory): @implementer(IPolicyForHTTPS) -class FederationPolicyForHTTPS(object): +class FederationPolicyForHTTPS: """Factory for Twisted SSLClientConnectionCreators that are used to make connections to remote servers for federation. @@ -144,7 +152,7 @@ class FederationPolicyForHTTPS(object): @implementer(IPolicyForHTTPS) -class RegularPolicyForHTTPS(object): +class RegularPolicyForHTTPS: """Factory for Twisted SSLClientConnectionCreators that are used to make connections to remote servers, for other than federation. @@ -181,7 +189,7 @@ def _context_info_cb(ssl_connection, where, ret): @implementer(IOpenSSLClientConnectionCreator) -class SSLClientConnectionCreator(object): +class SSLClientConnectionCreator: """Creates openssl connection objects for client connections. Replaces twisted.internet.ssl.ClientTLSOptions @@ -206,7 +214,7 @@ class SSLClientConnectionCreator(object): return connection -class ConnectionVerifier(object): +class ConnectionVerifier: """Set the SNI, and do cert verification This is a thing which is attached to the TLSMemoryBIOProtocol, and is called by diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index a9f4025bfe..32c31b1cd1 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -15,11 +15,9 @@ # limitations under the License. import logging +import urllib from collections import defaultdict -import six -from six.moves import urllib - import attr from signedjson.key import ( decode_verify_key_bytes, @@ -59,7 +57,7 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, cmp=False) -class VerifyJsonRequest(object): +class VerifyJsonRequest: """ A request to verify a JSON object. @@ -98,7 +96,7 @@ class KeyLookupError(ValueError): pass -class Keyring(object): +class Keyring: def __init__(self, hs, key_fetchers=None): self.clock = hs.get_clock() @@ -225,8 +223,7 @@ class Keyring(object): return results - @defer.inlineCallbacks - def _start_key_lookups(self, verify_requests): + async def _start_key_lookups(self, verify_requests): """Sets off the key fetches for each verify request Once each fetch completes, verify_request.key_ready will be resolved. @@ -247,7 +244,7 @@ class Keyring(object): server_to_request_ids.setdefault(server_name, set()).add(request_id) # Wait for any previous lookups to complete before proceeding. - yield self.wait_for_previous_lookups(server_to_request_ids.keys()) + await self.wait_for_previous_lookups(server_to_request_ids.keys()) # take out a lock on each of the servers by sticking a Deferred in # key_downloads @@ -285,15 +282,14 @@ class Keyring(object): except Exception: logger.exception("Error starting key lookups") - @defer.inlineCallbacks - def wait_for_previous_lookups(self, server_names): + async def wait_for_previous_lookups(self, server_names) -> None: """Waits for any previous key lookups for the given servers to finish. Args: server_names (Iterable[str]): list of servers which we want to look up Returns: - Deferred[None]: resolves once all key lookups for the given servers have + Resolves once all key lookups for the given servers have completed. Follows the synapse rules of logcontext preservation. """ loop_count = 1 @@ -311,7 +307,7 @@ class Keyring(object): loop_count, ) with PreserveLoggingContext(): - yield defer.DeferredList((w[1] for w in wait_on)) + await defer.DeferredList((w[1] for w in wait_on)) loop_count += 1 @@ -328,44 +324,44 @@ class Keyring(object): remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} - @defer.inlineCallbacks - def do_iterations(): - with Measure(self.clock, "get_server_verify_keys"): - for f in self._key_fetchers: - if not remaining_requests: - return - yield self._attempt_key_fetches_with_fetcher(f, remaining_requests) + async def do_iterations(): + try: + with Measure(self.clock, "get_server_verify_keys"): + for f in self._key_fetchers: + if not remaining_requests: + return + await self._attempt_key_fetches_with_fetcher( + f, remaining_requests + ) - # look for any requests which weren't satisfied + # look for any requests which weren't satisfied + with PreserveLoggingContext(): + for verify_request in remaining_requests: + verify_request.key_ready.errback( + SynapseError( + 401, + "No key for %s with ids in %s (min_validity %i)" + % ( + verify_request.server_name, + verify_request.key_ids, + verify_request.minimum_valid_until_ts, + ), + Codes.UNAUTHORIZED, + ) + ) + except Exception as err: + # we don't really expect to get here, because any errors should already + # have been caught and logged. But if we do, let's log the error and make + # sure that all of the deferreds are resolved. + logger.error("Unexpected error in _get_server_verify_keys: %s", err) with PreserveLoggingContext(): for verify_request in remaining_requests: - verify_request.key_ready.errback( - SynapseError( - 401, - "No key for %s with ids in %s (min_validity %i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ), - Codes.UNAUTHORIZED, - ) - ) + if not verify_request.key_ready.called: + verify_request.key_ready.errback(err) - def on_err(err): - # we don't really expect to get here, because any errors should already - # have been caught and logged. But if we do, let's log the error and make - # sure that all of the deferreds are resolved. - logger.error("Unexpected error in _get_server_verify_keys: %s", err) - with PreserveLoggingContext(): - for verify_request in remaining_requests: - if not verify_request.key_ready.called: - verify_request.key_ready.errback(err) - - run_in_background(do_iterations).addErrback(on_err) + run_in_background(do_iterations) - @defer.inlineCallbacks - def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): + async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): """Use a key fetcher to attempt to satisfy some key requests Args: @@ -392,7 +388,7 @@ class Keyring(object): verify_request.minimum_valid_until_ts, ) - results = yield fetcher.get_keys(missing_keys) + results = await fetcher.get_keys(missing_keys) completed = [] for verify_request in remaining_requests: @@ -424,8 +420,8 @@ class Keyring(object): remaining_requests.difference_update(completed) -class KeyFetcher(object): - def get_keys(self, keys_to_fetch): +class KeyFetcher: + async def get_keys(self, keys_to_fetch): """ Args: keys_to_fetch (dict[str, dict[str, int]]): @@ -444,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """see KeyFetcher.get_keys""" keys_to_fetch = ( @@ -454,20 +449,19 @@ class StoreKeyFetcher(KeyFetcher): for key_id in keys_for_server.keys() ) - res = yield self.store.get_server_verify_keys(keys_to_fetch) + res = await self.store.get_server_verify_keys(keys_to_fetch) keys = {} for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key return keys -class BaseV2KeyFetcher(object): +class BaseV2KeyFetcher: def __init__(self, hs): self.store = hs.get_datastore() self.config = hs.get_config() - @defer.inlineCallbacks - def process_v2_response(self, from_server, response_json, time_added_ms): + async def process_v2_response(self, from_server, response_json, time_added_ms): """Parse a 'Server Keys' structure from the result of a /key request This is used to parse either the entirety of the response from @@ -539,7 +533,7 @@ class BaseV2KeyFetcher(object): key_json_bytes = encode_canonical_json(response_json) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -569,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): self.client = hs.get_http_client() self.key_servers = self.config.key_servers - @defer.inlineCallbacks - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """see KeyFetcher.get_keys""" - @defer.inlineCallbacks - def get_key(key_server): + async def get_key(key_server): try: - result = yield self.get_server_verify_key_v2_indirect( + result = await self.get_server_verify_key_v2_indirect( keys_to_fetch, key_server ) return result @@ -594,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return {} - results = yield make_deferred_yieldable( + results = await make_deferred_yieldable( defer.gatherResults( [run_in_background(get_key, server) for server in self.key_servers], consumeErrors=True, @@ -608,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return union_of_keys - @defer.inlineCallbacks - def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): + async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): """ Args: keys_to_fetch (dict[str, dict[str, int]]): @@ -619,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): the keys Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map + dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map from server_name -> key_id -> FetchKeyResult Raises: @@ -634,7 +625,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) try: - query_response = yield self.client.post_json( + query_response = await self.client.post_json( destination=perspective_name, path="/_matrix/key/v2/query", data={ @@ -661,7 +652,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): for response in query_response["server_keys"]: # do this first, so that we can give useful errors thereafter server_name = response.get("server_name") - if not isinstance(server_name, six.string_types): + if not isinstance(server_name, str): raise KeyLookupError( "Malformed response from key notary server %s: invalid server_name" % (perspective_name,) @@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): try: self._validate_perspectives_response(key_server, response) - processed_response = yield self.process_v2_response( + processed_response = await self.process_v2_response( perspective_name, response, time_added_ms=time_now_ms ) except KeyLookupError as e: @@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) keys.setdefault(server_name, {}).update(processed_response) - yield self.store.store_server_verify_keys( + await self.store.store_server_verify_keys( perspective_name, time_now_ms, added_keys ) @@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher): self.clock = hs.get_clock() self.client = hs.get_http_client() - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """ Args: keys_to_fetch (dict[str, iterable[str]]): the keys to be fetched. server_name -> key_ids Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: + dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]: map from server_name -> key_id -> FetchKeyResult """ results = {} - @defer.inlineCallbacks - def get_key(key_to_fetch_item): + async def get_key(key_to_fetch_item): server_name, key_ids = key_to_fetch_item try: - keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids) + keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) results[server_name] = keys except KeyLookupError as e: logger.warning( @@ -767,12 +757,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except Exception: logger.exception("Error getting keys %s from %s", key_ids, server_name) - return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback( - lambda _: results - ) + await yieldable_gather_results(get_key, keys_to_fetch.items()) + return results - @defer.inlineCallbacks - def get_server_verify_key_v2_direct(self, server_name, key_ids): + async def get_server_verify_key_v2_direct(self, server_name, key_ids): """ Args: @@ -780,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): key_ids (iterable[str]): Returns: - Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result + dict[str, FetchKeyResult]: map from key ID to lookup result Raises: KeyLookupError if there was a problem making the lookup @@ -794,7 +782,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): time_now_ms = self.clock.time_msec() try: - response = yield self.client.get_json( + response = await self.client.get_json( destination=server_name, path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id), @@ -825,12 +813,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher): % (server_name, response["server_name"]) ) - response_keys = yield self.process_v2_response( + response_keys = await self.process_v2_response( from_server=server_name, response_json=response, time_added_ms=time_now_ms, ) - yield self.store.store_server_verify_keys( + await self.store.store_server_verify_keys( server_name, time_now_ms, ((server_name, key_id, key) for key_id, key in response_keys.items()), @@ -840,22 +828,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher): return keys -@defer.inlineCallbacks -def _handle_key_deferred(verify_request): +async def _handle_key_deferred(verify_request) -> None: """Waits for the key to become available, and then performs a verification Args: verify_request (VerifyJsonRequest): - Returns: - Deferred[None] - Raises: SynapseError if there was a problem performing the verification """ server_name = verify_request.server_name with PreserveLoggingContext(): - _, key_id, verify_key = yield verify_request.key_ready + _, key_id, verify_key = await verify_request.key_ready json_object = verify_request.json_object diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c582355146..8c907ad596 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -47,7 +47,7 @@ def check( Args: room_version_obj: the version of the room event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + auth_events: the existing room state. Raises: AuthError if the checks fail @@ -65,14 +65,16 @@ def check( room_id = event.room_id - # I'm not really expecting to get auth events in the wrong room, but let's - # sanity-check it + # We need to ensure that the auth events are actually for the same room, to + # stop people from using powers they've been granted in other rooms for + # example. for auth_event in auth_events.values(): if auth_event.room_id != room_id: - raise Exception( + raise AuthError( + 403, "During auth for event %s in room %s, found event %s in the state " "which is in room %s" - % (event.event_id, room_id, auth_event.event_id, auth_event.room_id) + % (event.event_id, room_id, auth_event.event_id, auth_event.room_id), ) if do_sig_check: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 533ba327f5..bf800a3852 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -18,9 +18,7 @@ import abc import os from distutils.util import strtobool -from typing import Dict, Optional, Type - -import six +from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 @@ -98,7 +96,7 @@ class DefaultDictProperty(DictProperty): return instance._dict.get(self.key, self.default) -class _EventInternalMetadata(object): +class _EventInternalMetadata: __slots__ = ["_dict"] def __init__(self, internal_metadata_dict: JsonDict): @@ -122,7 +120,7 @@ class _EventInternalMetadata(object): # be here before = DictProperty("before") # type: str after = DictProperty("after") # type: str - order = DictProperty("order") # type: int + order = DictProperty("order") # type: Tuple[int, int] def get_dict(self) -> JsonDict: return dict(self._dict) @@ -135,6 +133,8 @@ class _EventInternalMetadata(object): rejection. This is needed as those events are marked as outliers, but they still need to be processed as if they're new events (e.g. updating invite state in the database, relaying to clients, etc). + + (Added in synapse 0.99.0, so may be unreliable for events received before that) """ return self._dict.get("out_of_band_membership", False) @@ -290,7 +290,7 @@ class EventBase(metaclass=abc.ABCMeta): return list(self._dict.items()) def keys(self): - return six.iterkeys(self._dict) + return self._dict.keys() def prev_event_ids(self): """Returns the list of prev event IDs. The order matches the order diff --git a/synapse/events/builder.py b/synapse/events/builder.py index a0c4a40c27..b6c47be646 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple, Union import attr from nacl.signing import SigningKey -from twisted.internet import defer - +from synapse.api.auth import Auth from synapse.api.constants import MAX_DEPTH from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( @@ -29,13 +28,15 @@ from synapse.api.room_versions import ( ) from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict +from synapse.state import StateHandler +from synapse.storage.databases.main import DataStore from synapse.types import EventID, JsonDict from synapse.util import Clock from synapse.util.stringutils import random_string @attr.s(slots=True, cmp=False, frozen=True) -class EventBuilder(object): +class EventBuilder: """A format independent event builder used to build up the event content before signing the event. @@ -44,45 +45,46 @@ class EventBuilder(object): Attributes: room_version: Version of the target room - room_id (str) - type (str) - sender (str) - content (dict) - unsigned (dict) - internal_metadata (_EventInternalMetadata) - - _state (StateHandler) - _auth (synapse.api.Auth) - _store (DataStore) - _clock (Clock) - _hostname (str): The hostname of the server creating the event + room_id + type + sender + content + unsigned + internal_metadata + + _state + _auth + _store + _clock + _hostname: The hostname of the server creating the event _signing_key: The signing key to use to sign the event as the server """ - _state = attr.ib() - _auth = attr.ib() - _store = attr.ib() - _clock = attr.ib() - _hostname = attr.ib() - _signing_key = attr.ib() + _state = attr.ib(type=StateHandler) + _auth = attr.ib(type=Auth) + _store = attr.ib(type=DataStore) + _clock = attr.ib(type=Clock) + _hostname = attr.ib(type=str) + _signing_key = attr.ib(type=SigningKey) room_version = attr.ib(type=RoomVersion) - room_id = attr.ib() - type = attr.ib() - sender = attr.ib() + room_id = attr.ib(type=str) + type = attr.ib(type=str) + sender = attr.ib(type=str) - content = attr.ib(default=attr.Factory(dict)) - unsigned = attr.ib(default=attr.Factory(dict)) + content = attr.ib(default=attr.Factory(dict), type=JsonDict) + unsigned = attr.ib(default=attr.Factory(dict), type=JsonDict) # These only exist on a subset of events, so they raise AttributeError if # someone tries to get them when they don't exist. - _state_key = attr.ib(default=None) - _redacts = attr.ib(default=None) - _origin_server_ts = attr.ib(default=None) + _state_key = attr.ib(default=None, type=Optional[str]) + _redacts = attr.ib(default=None, type=Optional[str]) + _origin_server_ts = attr.ib(default=None, type=Optional[int]) internal_metadata = attr.ib( - default=attr.Factory(lambda: _EventInternalMetadata({})) + default=attr.Factory(lambda: _EventInternalMetadata({})), + type=_EventInternalMetadata, ) @property @@ -95,31 +97,35 @@ class EventBuilder(object): def is_state(self): return self._state_key is not None - @defer.inlineCallbacks - def build(self, prev_event_ids): + async def build(self, prev_event_ids: List[str]) -> EventBase: """Transform into a fully signed and hashed event Args: - prev_event_ids (list[str]): The event IDs to use as the prev events + prev_event_ids: The event IDs to use as the prev events Returns: - Deferred[FrozenEvent] + The signed and hashed event. """ - state_ids = yield self._state.get_current_state_ids( + state_ids = await self._state.get_current_state_ids( self.room_id, prev_event_ids ) - auth_ids = yield self._auth.compute_auth_events(self, state_ids) + auth_ids = self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = yield self._store.add_event_hashes(auth_ids) - prev_events = yield self._store.add_event_hashes(prev_event_ids) + # The types of auth/prev events changes between event versions. + auth_events = await self._store.add_event_hashes( + auth_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events = await self._store.add_event_hashes( + prev_event_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: auth_events = auth_ids prev_events = prev_event_ids - old_depth = yield self._store.get_max_depth_of(prev_event_ids) + old_depth = await self._store.get_max_depth_of(prev_event_ids) depth = old_depth + 1 # we cap depth of generated events, to ensure that they are not @@ -137,7 +143,7 @@ class EventBuilder(object): "unsigned": self.unsigned, "depth": depth, "prev_state": [], - } + } # type: Dict[str, Any] if self.is_state(): event_dict["state_key"] = self._state_key @@ -158,11 +164,11 @@ class EventBuilder(object): ) -class EventBuilderFactory(object): +class EventBuilderFactory: def __init__(self, hs): self.clock = hs.get_clock() self.hostname = hs.hostname - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self.store = hs.get_datastore() self.state = hs.get_state_handler() diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 7c5f620d09..afecafe15c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,19 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union - -from six import iteritems +from typing import TYPE_CHECKING, Optional, Union import attr from frozendict import frozendict -from twisted.internet import defer - from synapse.appservice import ApplicationService +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import StateMap +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + @attr.s(slots=True) class EventContext: @@ -131,8 +131,7 @@ class EventContext: delta_ids=delta_ids, ) - @defer.inlineCallbacks - def serialize(self, event, store): + async def serialize(self, event: EventBase, store: "DataStore") -> dict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -148,7 +147,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids() + prev_state_ids = await self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -216,8 +215,7 @@ class EventContext: return self._state_group - @defer.inlineCallbacks - def get_current_state_ids(self): + async def get_current_state_ids(self) -> Optional[StateMap[str]]: """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -226,32 +224,31 @@ class EventContext: ``rejected`` is set. Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group - is None, which happens when the associated event is an outlier. + Returns None if state_group is None, which happens when the associated + event is an outlier. - Maps a (type, state_key) to the event ID of the state event matching - this tuple. + Maps a (type, state_key) to the event ID of the state event matching + this tuple. """ if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched() + await self._ensure_fetched() return self._current_state_ids - @defer.inlineCallbacks - def get_prev_state_ids(self): + async def get_prev_state_ids(self): """ Gets the room state map, excluding this event. For a non-state event, this will be the same as get_current_state_ids(). Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group + dict[(str, str), str]|None: Returns None if state_group is None, which happens when the associated event is an outlier. Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched() + await self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -271,8 +268,8 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self): - return defer.succeed(None) + async def _ensure_fetched(self): + return None @attr.s(slots=True) @@ -305,21 +302,20 @@ class _AsyncEventContextImpl(EventContext): _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self): + async def _ensure_fetched(self): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background(self._fill_out_state) - return make_deferred_yieldable(self._fetching_state_deferred) + return await make_deferred_yieldable(self._fetching_state_deferred) - @defer.inlineCallbacks - def _fill_out_state(self): + async def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self._current_state_ids = await self._storage.state.get_state_ids_for_group( self.state_group ) if self._event_state_key is not None: @@ -341,7 +337,7 @@ def _encode_state_dict(state_dict): if state_dict is None: return None - return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)] + return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()] def _decode_state_dict(input): diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 1ffc9525d1..b0fc859a47 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,16 +15,17 @@ # limitations under the License. import inspect -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple -from synapse.spam_checker_api import SpamCheckerApi +from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi +from synapse.types import Collection MYPY = False if MYPY: import synapse.server -class SpamChecker(object): +class SpamChecker: def __init__(self, hs: "synapse.server.HomeServer"): self.spam_checkers = [] # type: List[Any] @@ -160,3 +161,33 @@ class SpamChecker(object): return True return False + + def check_registration_for_spam( + self, + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: + """Checks if we should allow the given registration request. + + Args: + email_threepid: The email threepid used for registering, if any + username: The request user name, if any + request_info: List of tuples of user agent and IP that + were used during the registration process. + + Returns: + Enum for how the request should be handled + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_registration_for_spam", None) + if checker: + behaviour = checker(email_threepid, username, request_info) + assert isinstance(behaviour, RegistrationBehaviour) + if behaviour != RegistrationBehaviour.ALLOW: + return behaviour + + return RegistrationBehaviour.ALLOW diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 459132d388..9d5310851c 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.types import Requester -class ThirdPartyEventRules(object): +class ThirdPartyEventRules: """Allows server admins to provide a Python module implementing an extra set of rules to apply when processing events. @@ -39,76 +41,79 @@ class ThirdPartyEventRules(object): config=config, http_client=hs.get_simple_http_client() ) - @defer.inlineCallbacks - def check_event_allowed(self, event, context): + async def check_event_allowed( + self, event: EventBase, context: EventContext + ) -> bool: """Check if a provided event should be allowed in the given context. Args: - event (synapse.events.EventBase): The event to be checked. - context (synapse.events.snapshot.EventContext): The context of the event. + event: The event to be checked. + context: The context of the event. Returns: - defer.Deferred[bool]: True if the event should be allowed, False if not. + True if the event should be allowed, False if not. """ if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() # Retrieve the state events from the database. state_events = {} for key, event_id in prev_state_ids.items(): - state_events[key] = yield self.store.get_event(event_id, allow_none=True) + state_events[key] = await self.store.get_event(event_id, allow_none=True) - ret = yield self.third_party_rules.check_event_allowed(event, state_events) + ret = await self.third_party_rules.check_event_allowed(event, state_events) return ret - @defer.inlineCallbacks - def on_create_room(self, requester, config, is_requester_admin): + async def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ) -> bool: """Intercept requests to create room to allow, deny or update the request config. Args: - requester (Requester) - config (dict): The creation config from the client. - is_requester_admin (bool): If the requester is an admin + requester + config: The creation config from the client. + is_requester_admin: If the requester is an admin Returns: - defer.Deferred[bool]: Whether room creation is allowed or denied. + Whether room creation is allowed or denied. """ if self.third_party_rules is None: return True - ret = yield self.third_party_rules.on_create_room( + ret = await self.third_party_rules.on_create_room( requester, config, is_requester_admin ) return ret - @defer.inlineCallbacks - def check_threepid_can_be_invited(self, medium, address, room_id): + async def check_threepid_can_be_invited( + self, medium: str, address: str, room_id: str + ) -> bool: """Check if a provided 3PID can be invited in the given room. Args: - medium (str): The 3PID's medium. - address (str): The 3PID's address. - room_id (str): The room we want to invite the threepid to. + medium: The 3PID's medium. + address: The 3PID's address. + room_id: The room we want to invite the threepid to. Returns: - defer.Deferred[bool], True if the 3PID can be invited, False if not. + True if the 3PID can be invited, False if not. """ if self.third_party_rules is None: return True - state_ids = yield self.store.get_filtered_current_state_ids(room_id) - room_state_events = yield self.store.get_events(state_ids.values()) + state_ids = await self.store.get_filtered_current_state_ids(room_id) + room_state_events = await self.store.get_events(state_ids.values()) state_events = {} for key, event_id in state_ids.items(): state_events[key] = room_state_events[event_id] - ret = yield self.third_party_rules.check_threepid_can_be_invited( + ret = await self.third_party_rules.check_threepid_can_be_invited( medium, address, state_events ) return ret diff --git a/synapse/events/utils.py b/synapse/events/utils.py index dd340be9a7..32c73d3413 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -12,16 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections +import collections.abc import re from typing import Any, Mapping, Union -from six import string_types - from frozendict import frozendict -from twisted.internet import defer - from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion @@ -318,7 +314,7 @@ def serialize_event( if only_event_fields: if not isinstance(only_event_fields, list) or not all( - isinstance(f, string_types) for f in only_event_fields + isinstance(f, str) for f in only_event_fields ): raise TypeError("only_event_fields must be a list of strings") d = only_fields(d, only_event_fields) @@ -326,7 +322,7 @@ def serialize_event( return d -class EventClientSerializer(object): +class EventClientSerializer: """Serializes events that are to be sent to clients. This is used for bundling extra information with any events to be sent to @@ -339,8 +335,9 @@ class EventClientSerializer(object): hs.config.experimental_msc1849_support_enabled ) - @defer.inlineCallbacks - def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs): + async def serialize_event( + self, event, time_now, bundle_aggregations=True, **kwargs + ): """Serializes a single event. Args: @@ -350,7 +347,7 @@ class EventClientSerializer(object): **kwargs: Arguments to pass to `serialize_event` Returns: - Deferred[dict]: The serialized event + dict: The serialized event """ # To handle the case of presence events and the like if not isinstance(event, EventBase): @@ -365,8 +362,8 @@ class EventClientSerializer(object): if not event.internal_metadata.is_redacted() and ( self.experimental_msc1849_support_enabled and bundle_aggregations ): - annotations = yield self.store.get_aggregation_groups_for_event(event_id) - references = yield self.store.get_relations_for_event( + annotations = await self.store.get_aggregation_groups_for_event(event_id) + references = await self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" ) @@ -380,7 +377,7 @@ class EventClientSerializer(object): edit = None if event.type == EventTypes.Message: - edit = yield self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id) if edit: # If there is an edit replace the content, preserving existing @@ -426,7 +423,7 @@ def copy_power_levels_contents( Raises: TypeError if the input does not look like a valid power levels event content """ - if not isinstance(old_power_levels, collections.Mapping): + if not isinstance(old_power_levels, collections.abc.Mapping): raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) power_levels = {} @@ -436,7 +433,7 @@ def copy_power_levels_contents( power_levels[k] = v continue - if isinstance(v, collections.Mapping): + if isinstance(v, collections.abc.Mapping): power_levels[k] = h = {} for k1, v1 in v.items(): # we should only have one level of nesting diff --git a/synapse/events/validator.py b/synapse/events/validator.py index b001c64bb4..9df35b54ba 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import integer_types, string_types - from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions @@ -22,7 +20,7 @@ from synapse.events.utils import validate_canonicaljson from synapse.types import EventID, RoomID, UserID -class EventValidator(object): +class EventValidator: def validate_new(self, event, config): """Validates the event has roughly the right format @@ -53,7 +51,7 @@ class EventValidator(object): event_strings = ["origin"] for s in event_strings: - if not isinstance(getattr(event, s), string_types): + if not isinstance(getattr(event, s), str): raise SynapseError(400, "'%s' not a string type" % (s,)) # Depending on the room version, ensure the data is spec compliant JSON. @@ -76,87 +74,34 @@ class EventValidator(object): ) if event.type == EventTypes.Retention: - self._validate_retention(event, config) + self._validate_retention(event) - def _validate_retention(self, event, config): + def _validate_retention(self, event): """Checks that an event that defines the retention policy for a room respects the - boundaries imposed by the server's administrator. + format enforced by the spec. Args: event (FrozenEvent): The event to validate. - config (Config): The homeserver's configuration. """ min_lifetime = event.content.get("min_lifetime") max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if not isinstance(min_lifetime, integer_types): + if not isinstance(min_lifetime, int): raise SynapseError( code=400, msg="'min_lifetime' must be an integer", errcode=Codes.BAD_JSON, ) - if ( - config.retention_allowed_lifetime_min is not None - and min_lifetime < config.retention_allowed_lifetime_min - ): - raise SynapseError( - code=400, - msg=( - "'min_lifetime' can't be lower than the minimum allowed" - " value enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - - if ( - config.retention_allowed_lifetime_max is not None - and min_lifetime > config.retention_allowed_lifetime_max - ): - raise SynapseError( - code=400, - msg=( - "'min_lifetime' can't be greater than the maximum allowed" - " value enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - if max_lifetime is not None: - if not isinstance(max_lifetime, integer_types): + if not isinstance(max_lifetime, int): raise SynapseError( code=400, msg="'max_lifetime' must be an integer", errcode=Codes.BAD_JSON, ) - if ( - config.retention_allowed_lifetime_min is not None - and max_lifetime < config.retention_allowed_lifetime_min - ): - raise SynapseError( - code=400, - msg=( - "'max_lifetime' can't be lower than the minimum allowed value" - " enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - - if ( - config.retention_allowed_lifetime_max is not None - and max_lifetime > config.retention_allowed_lifetime_max - ): - raise SynapseError( - code=400, - msg=( - "'max_lifetime' can't be greater than the maximum allowed" - " value enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - if ( min_lifetime is not None and max_lifetime is not None @@ -183,7 +128,7 @@ class EventValidator(object): strings.append("state_key") for s in strings: - if not isinstance(getattr(event, s), string_types): + if not isinstance(getattr(event, s), str): raise SynapseError(400, "Not '%s' a string type" % (s,)) RoomID.from_string(event.room_id) @@ -223,7 +168,7 @@ class EventValidator(object): for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) - if not isinstance(d[s], string_types): + if not isinstance(d[s], str): raise SynapseError(400, "'%s' not a string type" % (s,)) def _ensure_state_event(self, event): diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index c0012c6872..38aa47963f 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import Iterable, List -import six - from twisted.internet import defer from twisted.internet.defer import Deferred, DeferredList from twisted.python.failure import Failure @@ -41,7 +39,7 @@ from synapse.types import JsonDict, get_domain_from_id logger = logging.getLogger(__name__) -class FederationBase(object): +class FederationBase: def __init__(self, hs): self.hs = hs @@ -93,8 +91,8 @@ class FederationBase(object): # *actual* redacted copy to be on the safe side.) redacted_event = prune_event(pdu) if set(redacted_event.keys()) == set(pdu.keys()) and set( - six.iterkeys(redacted_event.content) - ) == set(six.iterkeys(pdu.content)): + redacted_event.content.keys() + ) == set(pdu.content.keys()): logger.info( "Event %s seems to have been redacted; using our redacted " "copy", @@ -294,7 +292,7 @@ def event_from_pdu_json( assert_params_in_dict(pdu_json, ("type", "depth")) depth = pdu_json["depth"] - if not isinstance(depth, six.integer_types): + if not isinstance(depth, int): raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 687cd841ac..38ac7ec699 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -87,7 +87,7 @@ class FederationClient(FederationBase): self.transport_layer = hs.get_federation_transport_client() self.hostname = hs.hostname - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self._get_pdu_cache = ExpiringCache( cache_name="get_pdu_cache", @@ -135,7 +135,7 @@ class FederationClient(FederationBase): and try the request anyway. Returns: - a Deferred which will eventually yield a JSON object from the + a Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels(query_type).inc() @@ -157,7 +157,7 @@ class FederationClient(FederationBase): content (dict): The query content. Returns: - a Deferred which will eventually yield a JSON object from the + an Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_device_keys").inc() @@ -180,7 +180,7 @@ class FederationClient(FederationBase): content (dict): The query content. Returns: - a Deferred which will eventually yield a JSON object from the + an Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() @@ -245,7 +245,7 @@ class FederationClient(FederationBase): event_id: event to fetch room_version: version of the room outlier: Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitary point in the context as opposed to part + it's from an arbitrary point in the context as opposed to part of the current block of PDUs. Defaults to `False` timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. @@ -351,7 +351,7 @@ class FederationClient(FederationBase): outlier: bool = False, include_none: bool = False, ) -> List[EventBase]: - """Takes a list of PDUs and checks the signatures and hashs of each + """Takes a list of PDUs and checks the signatures and hashes of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of that PDU. @@ -374,29 +374,26 @@ class FederationClient(FederationBase): """ deferreds = self._check_sigs_and_hashes(room_version, pdus) - @defer.inlineCallbacks - def handle_check_result(pdu: EventBase, deferred: Deferred): + async def handle_check_result(pdu: EventBase, deferred: Deferred): try: - res = yield make_deferred_yieldable(deferred) + res = await make_deferred_yieldable(deferred) except SynapseError: res = None if not res: # Check local db. - res = yield self.store.get_event( + res = await self.store.get_event( pdu.event_id, allow_rejected=True, allow_none=True ) if not res and pdu.origin != origin: try: - res = yield defer.ensureDeferred( - self.get_pdu( - destinations=[pdu.origin], - event_id=pdu.event_id, - room_version=room_version, - outlier=outlier, - timeout=10000, - ) + res = await self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + room_version=room_version, + outlier=outlier, + timeout=10000, ) except SynapseError: pass @@ -903,7 +900,7 @@ class FederationClient(FederationBase): party instance Returns: - Deferred[Dict[str, Any]]: The response from the remote server, or None if + Awaitable[Dict[str, Any]]: The response from the remote server, or None if `remote_server` is the same as the local server_name Raises: @@ -995,24 +992,25 @@ class FederationClient(FederationBase): raise RuntimeError("Failed to send to any server.") - @defer.inlineCallbacks - def get_room_complexity(self, destination, room_id): + async def get_room_complexity( + self, destination: str, room_id: str + ) -> Optional[dict]: """ Fetch the complexity of a remote room from another server. Args: - destination (str): The remote server - room_id (str): The room ID to ask about. + destination: The remote server + room_id: The room ID to ask about. Returns: - Deferred[dict] or Deferred[None]: Dict contains the complexity - metric versions, while None means we could not fetch the complexity. + Dict contains the complexity metric versions, while None means we + could not fetch the complexity. """ try: - complexity = yield self.transport_layer.get_room_complexity( + complexity = await self.transport_layer.get_room_complexity( destination=destination, room_id=room_id ) - defer.returnValue(complexity) + return complexity except CodeMessageException as e: # We didn't manage to get it -- probably a 404. We are okay if other # servers don't give it to us. @@ -1029,4 +1027,4 @@ class FederationClient(FederationBase): # If we don't manage to find it, return None. It's not an error if a # server doesn't give it to us. - defer.returnValue(None) + return None diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 32a8a2ee46..218df884b0 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -15,13 +15,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union - -import six -from six import iteritems +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Match, + Optional, + Tuple, + Union, +) -from canonicaljson import json -from prometheus_client import Counter +from prometheus_client import Counter, Histogram from twisted.internet import defer from twisted.internet.abstract import isIPAddress @@ -55,10 +62,13 @@ from synapse.replication.http.federation import ( ReplicationGetQueryRestServlet, ) from synapse.types import JsonDict, get_domain_from_id -from synapse.util import glob_to_regex, unwrapFirstError +from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +if TYPE_CHECKING: + from synapse.server import HomeServer + # when processing incoming transactions, we try to handle multiple rooms in # parallel, up to this limit. TRANSACTION_CONCURRENCY_LIMIT = 10 @@ -73,6 +83,10 @@ received_queries_counter = Counter( "synapse_federation_server_received_queries", "", ["type"] ) +pdu_process_time = Histogram( + "synapse_federation_server_pdu_process_time", "Time taken to process an event", +) + class FederationServer(FederationBase): def __init__(self, hs): @@ -94,6 +108,9 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) + self._state_ids_resp_cache = ResponseCache( + hs, "state_ids_resp", timeout_ms=30000 + ) async def on_backfill_request( self, origin: str, room_id: str, versions: List[str], limit: int @@ -274,21 +291,22 @@ class FederationServer(FederationBase): for pdu in pdus_by_room[room_id]: event_id = pdu.event_id - with nested_logging_context(event_id): - try: - await self._handle_received_pdu(origin, pdu) - pdu_results[event_id] = {} - except FederationError as e: - logger.warning("Error handling PDU %s: %s", event_id, e) - pdu_results[event_id] = {"error": str(e)} - except Exception as e: - f = failure.Failure() - pdu_results[event_id] = {"error": str(e)} - logger.error( - "Failed to handle PDU %s", - event_id, - exc_info=(f.type, f.value, f.getTracebackObject()), - ) + with pdu_process_time.time(): + with nested_logging_context(event_id): + try: + await self._handle_received_pdu(origin, pdu) + pdu_results[event_id] = {} + except FederationError as e: + logger.warning("Error handling PDU %s: %s", event_id, e) + pdu_results[event_id] = {"error": str(e)} + except Exception as e: + f = failure.Failure() + pdu_results[event_id] = {"error": str(e)} + logger.error( + "Failed to handle PDU %s", + event_id, + exc_info=(f.type, f.value, f.getTracebackObject()), + ) await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT @@ -360,10 +378,16 @@ class FederationServer(FederationBase): if not in_room: raise AuthError(403, "Host not in room.") + resp = await self._state_ids_resp_cache.wrap( + (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id, + ) + + return 200, resp + + async def _on_state_ids_request_compute(self, room_id, event_id): state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) - - return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} + return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( self, room_id: str, event_id: str @@ -524,9 +548,9 @@ class FederationServer(FederationBase): json_result = {} # type: Dict[str, Dict[str, dict]] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): + for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) + key_id: json_decoder.decode(json_str) } logger.info( @@ -534,9 +558,9 @@ class FederationServer(FederationBase): ",".join( ( "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) + for user_id, user_keys in json_result.items() + for device_id, device_keys in user_keys.items() + for key_id, _ in device_keys.items() ) ), ) @@ -715,7 +739,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # server name is a literal IP allow_ip_literals = acl_event.content.get("allow_ip_literals", True) if not isinstance(allow_ip_literals, bool): - logger.warning("Ignorning non-bool allow_ip_literals flag") + logger.warning("Ignoring non-bool allow_ip_literals flag") allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. @@ -729,7 +753,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # next, check the deny list deny = acl_event.content.get("deny", []) if not isinstance(deny, (list, tuple)): - logger.warning("Ignorning non-list deny ACL %s", deny) + logger.warning("Ignoring non-list deny ACL %s", deny) deny = [] for e in deny: if _acl_entry_matches(server_name, e): @@ -739,7 +763,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # then the allow list. allow = acl_event.content.get("allow", []) if not isinstance(allow, (list, tuple)): - logger.warning("Ignorning non-list allow ACL %s", allow) + logger.warning("Ignoring non-list allow ACL %s", allow) allow = [] for e in allow: if _acl_entry_matches(server_name, e): @@ -752,7 +776,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: - if not isinstance(acl_entry, six.string_types): + if not isinstance(acl_entry, str): logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) @@ -761,16 +785,35 @@ def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: return regex.match(server_name) -class FederationHandlerRegistry(object): +class FederationHandlerRegistry: """Allows classes to register themselves as handlers for a given EDU or query type for incoming federation traffic. """ - def __init__(self): - self.edu_handlers = {} - self.query_handlers = {} + def __init__(self, hs: "HomeServer"): + self.config = hs.config + self.http_client = hs.get_simple_http_client() + self.clock = hs.get_clock() + self._instance_name = hs.get_instance_name() - def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]): + # These are safe to load in monolith mode, but will explode if we try + # and use them. However we have guards before we use them to ensure that + # we don't route to ourselves, and in monolith mode that will always be + # the case. + self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) + self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) + + self.edu_handlers = ( + {} + ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] + self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] + + # Map from type to instance name that we should route EDU handling to. + self._edu_type_to_instance = {} # type: Dict[str, str] + + def register_edu_handler( + self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]] + ): """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. @@ -807,66 +850,56 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler + def register_instance_for_edu(self, edu_type: str, instance_name: str): + """Register that the EDU handler is on a different instance than master. + """ + self._edu_type_to_instance[edu_type] = instance_name + async def on_edu(self, edu_type: str, origin: str, content: dict): + if not self.config.use_presence and edu_type == "m.presence": + return + + # Check if we have a handler on this instance handler = self.edu_handlers.get(edu_type) - if not handler: - logger.warning("No handler registered for EDU type %s", edu_type) + if handler: + with start_active_span_from_edu(content, "handle_edu"): + try: + await handler(origin, content) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception: + logger.exception("Failed to handle edu %r", edu_type) return - with start_active_span_from_edu(content, "handle_edu"): + # Check if we can route it somewhere else that isn't us + route_to = self._edu_type_to_instance.get(edu_type, "master") + if route_to != self._instance_name: try: - await handler(origin, content) + await self._send_edu( + instance_name=route_to, + edu_type=edu_type, + origin=origin, + content=content, + ) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: logger.exception("Failed to handle edu %r", edu_type) - - def on_query(self, query_type: str, args: dict) -> defer.Deferred: - handler = self.query_handlers.get(query_type) - if not handler: - logger.warning("No handler registered for query type %s", query_type) - raise NotFoundError("No handler for Query type '%s'" % (query_type,)) - - return handler(args) - - -class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): - """A FederationHandlerRegistry for worker processes. - - When receiving EDU or queries it will check if an appropriate handler has - been registered on the worker, if there isn't one then it calls off to the - master process. - """ - - def __init__(self, hs): - self.config = hs.config - self.http_client = hs.get_simple_http_client() - self.clock = hs.get_clock() - - self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) - self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) - - super(ReplicationFederationHandlerRegistry, self).__init__() - - async def on_edu(self, edu_type: str, origin: str, content: dict): - """Overrides FederationHandlerRegistry - """ - if not self.config.use_presence and edu_type == "m.presence": return - handler = self.edu_handlers.get(edu_type) - if handler: - return await super(ReplicationFederationHandlerRegistry, self).on_edu( - edu_type, origin, content - ) - - return await self._send_edu(edu_type=edu_type, origin=origin, content=content) + # Oh well, let's just log and move on. + logger.warning("No handler registered for EDU type %s", edu_type) async def on_query(self, query_type: str, args: dict): - """Overrides FederationHandlerRegistry - """ handler = self.query_handlers.get(query_type) if handler: return await handler(args) - return await self._get_query_client(query_type=query_type, args=args) + # Check if we can route it somewhere else that isn't us + if self._instance_name == "master": + return await self._get_query_client(query_type=query_type, args=args) + + # Uh oh, no handler! Let's raise an exception so the request returns an + # error. + logger.warning("No handler registered for query type %s", query_type) + raise NotFoundError("No handler for Query type '%s'" % (query_type,)) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index d68b4bd670..079e2b2fe0 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -20,13 +20,16 @@ These actions are mostly only used by the :py:mod:`.replication` module. """ import logging +from typing import Optional, Tuple +from synapse.federation.units import Transaction from synapse.logging.utils import log_function +from synapse.types import JsonDict logger = logging.getLogger(__name__) -class TransactionActions(object): +class TransactionActions: """ Defines persistence actions that relate to handling Transactions. """ @@ -34,30 +37,32 @@ class TransactionActions(object): self.store = datastore @log_function - def have_responded(self, origin, transaction): - """ Have we already responded to a transaction with the same id and + async def have_responded( + self, origin: str, transaction: Transaction + ) -> Optional[Tuple[int, JsonDict]]: + """Have we already responded to a transaction with the same id and origin? Returns: - Deferred: Results in `None` if we have not previously responded to - this transaction or a 2-tuple of `(int, dict)` representing the - response code and response body. + `None` if we have not previously responded to this transaction or a + 2-tuple of `(int, dict)` representing the response code and response body. """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.get_received_txn_response(transaction.transaction_id, origin) + return await self.store.get_received_txn_response(transaction_id, origin) @log_function - def set_response(self, origin, transaction, code, response): - """ Persist how we responded to a transaction. - - Returns: - Deferred + async def set_response( + self, origin: str, transaction: Transaction, code: int, response: JsonDict + ) -> None: + """Persist how we responded to a transaction. """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.set_received_txn_response( - transaction.transaction_id, origin, code, response + await self.store.set_received_txn_response( + transaction_id, origin, code, response ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 52f4f54215..8e46957d15 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -33,14 +33,12 @@ import logging from collections import namedtuple from typing import Dict, List, Tuple, Type -from six import iteritems - from sortedcontainers import SortedDict from twisted.internet import defer +from synapse.api.presence import UserPresenceState from synapse.metrics import LaterGauge -from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure from .units import Edu @@ -48,7 +46,7 @@ from .units import Edu logger = logging.getLogger(__name__) -class FederationRemoteSendQueue(object): +class FederationRemoteSendQueue: """A drop in replacement for FederationSender""" def __init__(self, hs): @@ -57,6 +55,11 @@ class FederationRemoteSendQueue(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id + # We may have multiple federation sender instances, so we need to track + # their positions separately. + self._sender_instances = hs.config.worker.federation_shard_config.instances + self._sender_positions = {} + # Pending presence map user_id -> UserPresenceState self.presence_map = {} # type: Dict[str, UserPresenceState] @@ -263,7 +266,14 @@ class FederationRemoteSendQueue(object): def get_current_token(self): return self.pos - 1 - def federation_ack(self, token): + def federation_ack(self, instance_name, token): + if self._sender_instances: + # If we have configured multiple federation sender instances we need + # to track their positions separately, and only clear the queue up + # to the token all instances have acked. + self._sender_positions[instance_name] = token + token = min(self._sender_positions.values()) + self._clear_queue_before_pos(token) async def get_replication_rows( @@ -327,7 +337,7 @@ class FederationRemoteSendQueue(object): # stream position. keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} - for ((destination, edu_key), pos) in iteritems(keyed_edus): + for ((destination, edu_key), pos) in keyed_edus.items(): rows.append( ( pos, @@ -355,13 +365,13 @@ class FederationRemoteSendQueue(object): ) -class BaseFederationRow(object): +class BaseFederationRow: """Base class for rows to be sent in the federation stream. Specifies how to identify, serialize and deserialize the different types. """ - TypeId = "" # Unique string that ids the type. Must be overriden in sub classes. + TypeId = "" # Unique string that ids the type. Must be overridden in sub classes. @staticmethod def from_data(data): @@ -530,10 +540,10 @@ def process_rows_for_federation(transaction_queue, rows): states=[state], destinations=destinations ) - for destination, edu_map in iteritems(buff.keyed_edus): + for destination, edu_map in buff.keyed_edus.items(): for key, edu in edu_map.items(): transaction_queue.send_edu(edu, key) - for destination, edu_list in iteritems(buff.edus): + for destination, edu_list in buff.edus.items(): for edu in edu_list: transaction_queue.send_edu(edu, None) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index d473576902..552519e82c 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -16,14 +16,13 @@ import logging from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple -from six import itervalues - from prometheus_client import Counter from twisted.internet import defer import synapse import synapse.metrics +from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager @@ -41,7 +40,6 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.metrics import Measure, measure_func @@ -58,7 +56,7 @@ sent_pdus_destination_dist_total = Counter( ) -class FederationSender(object): +class FederationSender: def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs self.server_name = hs.hostname @@ -71,6 +69,9 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) + self._instance_name = hs.get_instance_name() + self._federation_shard_config = hs.config.worker.federation_shard_config + # map from destination to PerDestinationQueue self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] @@ -107,8 +108,6 @@ class FederationSender(object): ), ) - self._order = 1 - self._is_processing = False self._last_poked_id = -1 @@ -193,7 +192,13 @@ class FederationSender(object): ) return - destinations = set(destinations) + destinations = { + d + for d in destinations + if self._federation_shard_config.should_handle( + self._instance_name, d + ) + } if send_on_behalf_of is not None: # If we are sending the event on behalf of another server @@ -203,7 +208,15 @@ class FederationSender(object): logger.debug("Sending %s to %r", event, destinations) - self._send_pdu(event, destinations) + if destinations: + self._send_pdu(event, destinations) + + now = self.clock.time_msec() + ts = await self.store.get_received_ts(event.event_id) + + synapse.metrics.event_processing_lag_by_event.labels( + "federation_sender" + ).observe((now - ts) / 1000) async def handle_room_events(events: Iterable[EventBase]) -> None: with Measure(self.clock, "handle_room_events"): @@ -218,7 +231,7 @@ class FederationSender(object): defer.gatherResults( [ run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) + for evs in events_by_room.values() ], consumeErrors=True, ) @@ -257,9 +270,6 @@ class FederationSender(object): # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. - order = self._order - self._order += 1 - destinations = set(destinations) destinations.discard(self.server_name) logger.debug("Sending to: %s", str(destinations)) @@ -271,10 +281,9 @@ class FederationSender(object): sent_pdus_destination_dist_count.inc() for destination in destinations: - self._get_per_destination_queue(destination).send_pdu(pdu, order) + self._get_per_destination_queue(destination).send_pdu(pdu) - @defer.inlineCallbacks - def send_read_receipt(self, receipt: ReadReceipt): + async def send_read_receipt(self, receipt: ReadReceipt) -> None: """Send a RR to any other servers in the room Args: @@ -315,8 +324,13 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = yield self.state.get_current_hosts_in_room(room_id) - domains = [d for d in domains if d != self.server_name] + domains_set = await self.state.get_current_hosts_in_room(room_id) + domains = [ + d + for d in domains_set + if d != self.server_name + and self._federation_shard_config.should_handle(self._instance_name, d) + ] if not domains: return @@ -365,8 +379,7 @@ class FederationSender(object): queue.flush_read_receipts_for_room(room_id) @preserve_fn # the caller should not yield on this - @defer.inlineCallbacks - def send_presence(self, states: List[UserPresenceState]): + async def send_presence(self, states: List[UserPresenceState]): """Send the new presence states to the appropriate destinations. This actually queues up the presence states ready for sending and @@ -401,7 +414,7 @@ class FederationSender(object): if not states_map: break - yield self._process_presence_inner(list(states_map.values())) + await self._process_presence_inner(list(states_map.values())) except Exception: logger.exception("Error sending presence states to servers") finally: @@ -421,20 +434,29 @@ class FederationSender(object): for destination in destinations: if destination == self.server_name: continue + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + continue self._get_per_destination_queue(destination).send_presence(states) @measure_func("txnqueue._process_presence") - @defer.inlineCallbacks - def _process_presence_inner(self, states: List[UserPresenceState]): + async def _process_presence_inner(self, states: List[UserPresenceState]): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination """ - hosts_and_states = yield get_interested_remotes(self.store, states, self.state) + hosts_and_states = await get_interested_remotes(self.store, states, self.state) for destinations, states in hosts_and_states: for destination in destinations: if destination == self.server_name: continue + + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + continue + self._get_per_destination_queue(destination).send_presence(states) def build_and_send_edu( @@ -456,6 +478,11 @@ class FederationSender(object): logger.info("Not sending EDU to ourselves") return + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + return + edu = Edu( origin=self.server_name, destination=destination, @@ -472,6 +499,11 @@ class FederationSender(object): edu: edu to send key: clobbering key for this edu """ + if not self._federation_shard_config.should_handle( + self._instance_name, edu.destination + ): + return + queue = self._get_per_destination_queue(edu.destination) if key: queue.send_keyed_edu(edu, key) @@ -483,6 +515,11 @@ class FederationSender(object): logger.warning("Not sending device update to ourselves") return + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + return + self._get_per_destination_queue(destination).attempt_new_transaction() def wake_destination(self, destination: str): @@ -496,6 +533,11 @@ class FederationSender(object): logger.warning("Not waking up ourselves") return + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + return + self._get_per_destination_queue(destination).attempt_new_transaction() @staticmethod diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 4e698981a4..defc228c23 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -24,12 +24,12 @@ from synapse.api.errors import ( HttpResponseException, RequestSendFailed, ) +from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter @@ -53,7 +53,7 @@ sent_edus_by_type = Counter( ) -class PerDestinationQueue(object): +class PerDestinationQueue: """ Manages the per-destination transmission queues. @@ -74,12 +74,26 @@ class PerDestinationQueue(object): self._clock = hs.get_clock() self._store = hs.get_datastore() self._transaction_manager = transaction_manager + self._instance_name = hs.get_instance_name() + self._federation_shard_config = hs.config.worker.federation_shard_config + + self._should_send_on_this_instance = True + if not self._federation_shard_config.should_handle( + self._instance_name, destination + ): + # We don't raise an exception here to avoid taking out any other + # processing. We have a guard in `attempt_new_transaction` that + # ensure we don't start sending stuff. + logger.error( + "Create a per destination queue for %s on wrong worker", destination, + ) + self._should_send_on_this_instance = False self._destination = destination self.transmission_loop_running = False - # a list of tuples of (pending pdu, order) - self._pending_pdus = [] # type: List[Tuple[EventBase, int]] + # a list of pending PDUs + self._pending_pdus = [] # type: List[EventBase] # XXX this is never actually used: see # https://github.com/matrix-org/synapse/issues/7549 @@ -118,18 +132,17 @@ class PerDestinationQueue(object): + len(self._pending_edus_keyed) ) - def send_pdu(self, pdu: EventBase, order: int) -> None: - """Add a PDU to the queue, and start the transmission loop if neccessary + def send_pdu(self, pdu: EventBase) -> None: + """Add a PDU to the queue, and start the transmission loop if necessary Args: pdu: pdu to send - order """ - self._pending_pdus.append((pdu, order)) + self._pending_pdus.append(pdu) self.attempt_new_transaction() def send_presence(self, states: Iterable[UserPresenceState]) -> None: - """Add presence updates to the queue. Start the transmission loop if neccessary. + """Add presence updates to the queue. Start the transmission loop if necessary. Args: states: presence to send @@ -171,7 +184,7 @@ class PerDestinationQueue(object): returns immediately. Otherwise kicks off the process of sending a transaction in the background. """ - # list of (pending_pdu, deferred, order) + if self.transmission_loop_running: # XXX: this can get stuck on by a never-ending # request at which point pending_pdus just keeps growing. @@ -180,6 +193,14 @@ class PerDestinationQueue(object): logger.debug("TX [%s] Transaction already in progress", self._destination) return + if not self._should_send_on_this_instance: + # We don't raise an exception here to avoid taking out any other + # processing. + logger.error( + "Trying to start a transaction to %s on wrong worker", self._destination + ) + return + logger.debug("TX [%s] Starting transaction loop", self._destination) run_as_background_process( @@ -188,7 +209,7 @@ class PerDestinationQueue(object): ) async def _transaction_transmission_loop(self) -> None: - pending_pdus = [] # type: List[Tuple[EventBase, int]] + pending_pdus = [] # type: List[EventBase] try: self.transmission_loop_running = True @@ -315,6 +336,28 @@ class PerDestinationQueue(object): (e.retry_last_ts + e.retry_interval) / 1000.0 ), ) + + if e.retry_interval > 60 * 60 * 1000: + # we won't retry for another hour! + # (this suggests a significant outage) + # We drop pending PDUs and EDUs because otherwise they will + # rack up indefinitely. + # Note that: + # - the EDUs that are being dropped here are those that we can + # afford to drop (specifically, only typing notifications, + # read receipts and presence updates are being dropped here) + # - Other EDUs such as to_device messages are queued with a + # different mechanism + # - this is all volatile state that would be lost if the + # federation sender restarted anyway + + # dropping read receipts is a bit sad but should be solved + # through another mechanism, because this is all volatile! + self._pending_pdus = [] + self._pending_edus = [] + self._pending_edus_keyed = {} + self._pending_presence = {} + self._pending_rrs = {} except FederationDeniedError as e: logger.info(e) except HttpResponseException as e: @@ -329,13 +372,13 @@ class PerDestinationQueue(object): "TX [%s] Failed to send transaction: %s", self._destination, e ) - for p, _ in pending_pdus: + for p in pending_pdus: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) except Exception: logger.exception("TX [%s] Failed to send transaction", self._destination) - for p, _ in pending_pdus: + for p in pending_pdus: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index a2752a54a5..c84072ab73 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -15,8 +15,6 @@ import logging from typing import TYPE_CHECKING, List -from canonicaljson import json - from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -28,6 +26,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.util import json_decoder from synapse.util.metrics import measure_func if TYPE_CHECKING: @@ -36,7 +35,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class TransactionManager(object): +class TransactionManager: """Helper class which handles building and sending transactions shared between PerDestinationQueue objects @@ -54,33 +53,34 @@ class TransactionManager(object): @measure_func("_send_new_transaction") async def send_new_transaction( - self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] - ): + self, destination: str, pdus: List[EventBase], edus: List[Edu], + ) -> bool: + """ + Args: + destination: The destination to send to (e.g. 'example.org') + pdus: In-order list of PDUs to send + edus: List of EDUs to send + + Returns: + True iff the transaction was successful + """ # Make a transaction-sending opentracing span. This span follows on from # all the edus in that transaction. This needs to be done since there is # no active span here, so if the edus were not received by the remote the # span would have no causality and it would be forgotten. - # The span_contexts is a generator so that it won't be evaluated if - # opentracing is disabled. (Yay speed!) span_contexts = [] keep_destination = whitelisted_homeserver(destination) - for edu in pending_edus: + for edu in edus: context = edu.get_context() if context: - span_contexts.append(extract_text_map(json.loads(context))) + span_contexts.append(extract_text_map(json_decoder.decode(context))) if keep_destination: edu.strip_context() with start_active_span_follows_from("send_transaction", span_contexts): - - # Sort based on the order field - pending_pdus.sort(key=lambda t: t[1]) - pdus = [x[0] for x in pending_pdus] - edus = pending_edus - success = True logger.debug("TX [%s] _attempt_new_transaction", destination) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 060bf07197..17a10f622e 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,12 +15,9 @@ # limitations under the License. import logging +import urllib from typing import Any, Dict, Optional -from six.moves import urllib - -from twisted.internet import defer - from synapse.api.constants import Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.urls import ( @@ -33,7 +30,7 @@ from synapse.logging.utils import log_function logger = logging.getLogger(__name__) -class TransportLayerClient(object): +class TransportLayerClient: """Sends federation HTTP requests to other servers""" def __init__(self, hs): @@ -52,7 +49,7 @@ class TransportLayerClient(object): event_id (str): The event we want the context at. Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) @@ -76,7 +73,7 @@ class TransportLayerClient(object): giving up. None indicates no timeout. Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) @@ -97,7 +94,7 @@ class TransportLayerClient(object): limit (int) Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", @@ -119,16 +116,15 @@ class TransportLayerClient(object): destination, path=path, args=args, try_trailing_slash_on_400=True ) - @defer.inlineCallbacks @log_function - def send_transaction(self, transaction, json_data_callback=None): + async def send_transaction(self, transaction, json_data_callback=None): """ Sends the given Transaction to its destination Args: transaction (Transaction) Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Fails with ``HTTPRequestException`` if we get an HTTP response @@ -155,7 +151,7 @@ class TransportLayerClient(object): path = _create_v1_path("/send/%s", transaction.transaction_id) - response = yield self.client.put_json( + response = await self.client.put_json( transaction.destination, path=path, data=json_data, @@ -167,14 +163,13 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def make_query( + async def make_query( self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False ): path = _create_v1_path("/query/%s", query_type) - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, args=args, @@ -185,9 +180,10 @@ class TransportLayerClient(object): return content - @defer.inlineCallbacks @log_function - def make_membership_event(self, destination, room_id, user_id, membership, params): + async def make_membership_event( + self, destination, room_id, user_id, membership, params + ): """Asks a remote server to build and sign us a membership event Note that this does not append any events to any graphs. @@ -201,7 +197,7 @@ class TransportLayerClient(object): request. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body (ie, the new event). Fails with ``HTTPRequestException`` if we get an HTTP response @@ -232,7 +228,7 @@ class TransportLayerClient(object): ignore_backoff = True retry_on_dns_fail = True - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, args=params, @@ -243,34 +239,31 @@ class TransportLayerClient(object): return content - @defer.inlineCallbacks @log_function - def send_join_v1(self, destination, room_id, event_id, content): + async def send_join_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content ) return response - @defer.inlineCallbacks @log_function - def send_join_v2(self, destination, room_id, event_id, content): + async def send_join_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/send_join/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content ) return response - @defer.inlineCallbacks @log_function - def send_leave_v1(self, destination, room_id, event_id, content): + async def send_leave_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, @@ -283,12 +276,11 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def send_leave_v2(self, destination, room_id, event_id, content): + async def send_leave_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, @@ -301,31 +293,28 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def send_invite_v1(self, destination, room_id, event_id, content): + async def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) return response - @defer.inlineCallbacks @log_function - def send_invite_v2(self, destination, room_id, event_id, content): + async def send_invite_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/invite/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) return response - @defer.inlineCallbacks @log_function - def get_public_rooms( + async def get_public_rooms( self, remote_server: str, limit: Optional[int] = None, @@ -356,7 +345,7 @@ class TransportLayerClient(object): data["filter"] = search_filter try: - response = yield self.client.post_json( + response = await self.client.post_json( destination=remote_server, path=path, data=data, ignore_backoff=True ) except HttpResponseException as e: @@ -382,7 +371,7 @@ class TransportLayerClient(object): args["since"] = [since_token] try: - response = yield self.client.get_json( + response = await self.client.get_json( destination=remote_server, path=path, args=args, ignore_backoff=True ) except HttpResponseException as e: @@ -397,29 +386,26 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def exchange_third_party_invite(self, destination, room_id, event_dict): + async def exchange_third_party_invite(self, destination, room_id, event_dict): path = _create_v1_path("/exchange_third_party_invite/%s", room_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=event_dict ) return response - @defer.inlineCallbacks @log_function - def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth(self, destination, room_id, event_id): path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = yield self.client.get_json(destination=destination, path=path) + content = await self.client.get_json(destination=destination, path=path) return content - @defer.inlineCallbacks @log_function - def query_client_keys(self, destination, query_content, timeout): + async def query_client_keys(self, destination, query_content, timeout): """Query the device keys for a list of user ids hosted on a remote server. @@ -454,14 +440,13 @@ class TransportLayerClient(object): """ path = _create_v1_path("/user/keys/query") - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def query_user_devices(self, destination, user_id, timeout): + async def query_user_devices(self, destination, user_id, timeout): """Query the devices for a user id hosted on a remote server. Response: @@ -494,14 +479,13 @@ class TransportLayerClient(object): """ path = _create_v1_path("/user/devices/%s", user_id) - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def claim_client_keys(self, destination, query_content, timeout): + async def claim_client_keys(self, destination, query_content, timeout): """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -533,14 +517,13 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/claim") - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def get_missing_events( + async def get_missing_events( self, destination, room_id, @@ -552,7 +535,7 @@ class TransportLayerClient(object): ): path = _create_v1_path("/get_missing_events/%s", room_id) - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data={ @@ -747,7 +730,7 @@ class TransportLayerClient(object): def remove_user_from_group( self, destination, group_id, requester_user_id, user_id, content ): - """Remove a user fron a group + """Remove a user from a group """ path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index af4595498c..9325e0f857 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -20,8 +20,6 @@ import logging import re from typing import Optional, Tuple, Type -from twisted.internet.defer import maybeDeferred - import synapse from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.room_versions import RoomVersions @@ -102,14 +100,14 @@ class NoAuthenticationError(AuthenticationError): pass -class Authenticator(object): +class Authenticator: def __init__(self, hs: HomeServer): self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() self.federation_domain_whitelist = hs.config.federation_domain_whitelist - self.notifer = hs.get_notifier() + self.notifier = hs.get_notifier() self.replication_client = None if hs.config.worker.worker_app: @@ -175,7 +173,7 @@ class Authenticator(object): await self.store.set_destination_retry_timings(origin, None, 0, 0) # Inform the relevant places that the remote server is back up. - self.notifer.notify_remote_server_up(origin) + self.notifier.notify_remote_server_up(origin) if self.replication_client: # If we're on a worker we try and inform master about this. The # replication client doesn't hook into the notifier to avoid @@ -230,7 +228,7 @@ def _parse_auth_header(header_bytes): ) -class BaseFederationServlet(object): +class BaseFederationServlet: """Abstract base class for federation servlet classes. The servlet object should have a PATH attribute which takes the form of a regexp to @@ -340,6 +338,12 @@ class BaseFederationServlet(object): if origin: with ratelimiter.ratelimit(origin) as d: await d + if request._disconnected: + logger.warning( + "client disconnected before we started processing " + "request" + ) + return -1, None response = await func( origin, content, request.args, *args, **kwargs ) @@ -361,11 +365,7 @@ class BaseFederationServlet(object): continue server.register_paths( - method, - (pattern,), - self._wrap(code), - self.__class__.__name__, - trace=False, + method, (pattern,), self._wrap(code), self.__class__.__name__, ) @@ -799,12 +799,8 @@ class PublicRoomList(BaseFederationServlet): # zero is a special value which corresponds to no limit. limit = None - data = await maybeDeferred( - self.handler.get_local_public_room_list, - limit, - since_token, - network_tuple=network_tuple, - from_federation=True, + data = await self.handler.get_local_public_room_list( + limit, since_token, network_tuple=network_tuple, from_federation=True ) return 200, data diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 6b32e0dcbf..64d98fc8f6 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -107,9 +107,7 @@ class Transaction(JsonEncodedObject): if "edus" in kwargs and not kwargs["edus"]: del kwargs["edus"] - super(Transaction, self).__init__( - transaction_id=transaction_id, pdus=pdus, **kwargs - ) + super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) @staticmethod def create_new(pdus, **kwargs): diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 27b0c02655..a86b3debc5 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -41,8 +41,6 @@ from typing import Tuple from signedjson.sign import sign_json -from twisted.internet import defer - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id @@ -62,7 +60,7 @@ DEFAULT_ATTESTATION_JITTER = (0.9, 1.3) UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 -class GroupAttestationSigning(object): +class GroupAttestationSigning: """Creates and verifies group attestations. """ @@ -70,10 +68,11 @@ class GroupAttestationSigning(object): self.keyring = hs.get_keyring() self.clock = hs.get_clock() self.server_name = hs.hostname - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key - @defer.inlineCallbacks - def verify_attestation(self, attestation, group_id, user_id, server_name=None): + async def verify_attestation( + self, attestation, group_id, user_id, server_name=None + ): """Verifies that the given attestation matches the given parameters. An optional server_name can be supplied to explicitly set which server's @@ -102,7 +101,7 @@ class GroupAttestationSigning(object): if valid_until_ms < now: raise SynapseError(400, "Attestation expired") - yield self.keyring.verify_json_for_server( + await self.keyring.verify_json_for_server( server_name, attestation, now, "Group attestation" ) @@ -125,7 +124,7 @@ class GroupAttestationSigning(object): ) -class GroupAttestionRenewer(object): +class GroupAttestionRenewer: """Responsible for sending and receiving attestation updates. """ @@ -142,8 +141,7 @@ class GroupAttestionRenewer(object): self._start_renew_attestations, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def on_renew_attestation(self, group_id, user_id, content): + async def on_renew_attestation(self, group_id, user_id, content): """When a remote updates an attestation """ attestation = content["attestation"] @@ -151,11 +149,11 @@ class GroupAttestionRenewer(object): if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): raise SynapseError(400, "Neither user not group are on this server") - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( attestation, user_id=user_id, group_id=group_id ) - yield self.store.update_remote_attestion(group_id, user_id, attestation) + await self.store.update_remote_attestion(group_id, user_id, attestation) return {} @@ -172,8 +170,7 @@ class GroupAttestionRenewer(object): now + UPDATE_ATTESTATION_TIME_MS ) - @defer.inlineCallbacks - def _renew_attestation(group_user: Tuple[str, str]): + async def _renew_attestation(group_user: Tuple[str, str]): group_id, user_id = group_user try: if not self.is_mine_id(group_id): @@ -186,16 +183,16 @@ class GroupAttestionRenewer(object): user_id, group_id, ) - yield self.store.remove_attestation_renewal(group_id, user_id) + await self.store.remove_attestation_renewal(group_id, user_id) return attestation = self.attestations.create_attestation(group_id, user_id) - yield self.transport_client.renew_group_attestation( + await self.transport_client.renew_group_attestation( destination, group_id, user_id, content={"attestation": attestation} ) - yield self.store.update_attestation_renewal( + await self.store.update_attestation_renewal( group_id, user_id, attestation ) except (RequestSendFailed, HttpResponseException) as e: diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 14b3a7e22b..3fd7bedf99 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -17,8 +17,6 @@ import logging -from six import string_types - from synapse.api.errors import Codes, SynapseError from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute @@ -34,7 +32,7 @@ logger = logging.getLogger(__name__) # TODO: Flairs -class GroupsServerWorkerHandler(object): +class GroupsServerWorkerHandler: def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() @@ -43,7 +41,7 @@ class GroupsServerWorkerHandler(object): self.clock = hs.get_clock() self.keyring = hs.get_keyring() self.is_mine_id = hs.is_mine_id - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self.server_name = hs.hostname self.attestations = hs.get_groups_attestation_signing() self.transport_client = hs.get_federation_transport_client() @@ -513,7 +511,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): for keyname in ("name", "avatar_url", "short_description", "long_description"): if keyname in content: value = content[keyname] - if not isinstance(value, string_types): + if not isinstance(value, str): raise SynapseError(400, "%r value is not a string" % (keyname,)) profile[keyname] = value diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 2dd183018a..286f0054be 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -20,7 +20,7 @@ from .identity import IdentityHandler from .search import SearchHandler -class Handlers(object): +class Handlers: """ Deprecated. A collection of handlers. diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 61dc4beafe..0206320e96 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -15,8 +15,8 @@ import logging -from twisted.internet import defer - +import synapse.state +import synapse.storage import synapse.types from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter @@ -25,13 +25,9 @@ from synapse.types import UserID logger = logging.getLogger(__name__) -class BaseHandler(object): +class BaseHandler: """ Common base class for the event handlers. - - Attributes: - store (synapse.storage.DataStore): - state_handler (synapse.state.StateHandler): """ def __init__(self, hs): @@ -39,10 +35,10 @@ class BaseHandler(object): Args: hs (synapse.server.HomeServer): """ - self.store = hs.get_datastore() + self.store = hs.get_datastore() # type: synapse.storage.DataStore self.auth = hs.get_auth() self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() + self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler self.distributor = hs.get_distributor() self.clock = hs.get_clock() self.hs = hs @@ -68,8 +64,7 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() - @defer.inlineCallbacks - def ratelimit(self, requester, update=True, is_admin_redaction=False): + async def ratelimit(self, requester, update=True, is_admin_redaction=False): """Ratelimits requests. Args: @@ -101,7 +96,7 @@ class BaseHandler(object): burst_count = self._rc_message.burst_count # Check if there is a per user override in the DB. - override = yield self.store.get_ratelimit_for_user(user_id) + override = await self.store.get_ratelimit_for_user(user_id) if override: # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index a8d3fbc6de..9112a0ab86 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -14,7 +14,7 @@ # limitations under the License. -class AccountDataEventSource(object): +class AccountDataEventSource: def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 590135d19c..4caf6d591a 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -26,15 +26,10 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util import stringutils -try: - from synapse.push.mailer import load_jinja2_templates -except ImportError: - load_jinja2_templates = None - logger = logging.getLogger(__name__) -class AccountValidityHandler(object): +class AccountValidityHandler: def __init__(self, hs): self.hs = hs self.config = hs.config @@ -47,9 +42,11 @@ class AccountValidityHandler(object): if ( self._account_validity.enabled and self._account_validity.renew_by_email_enabled - and load_jinja2_templates ): # Don't do email-specific configuration if renewal by email is disabled. + self._template_html = self.config.account_validity_template_html + self._template_text = self.config.account_validity_template_text + try: app_name = self.hs.config.email_app_name @@ -65,17 +62,6 @@ class AccountValidityHandler(object): self._raw_from = email.utils.parseaddr(self._from_string)[1] - self._template_html, self._template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_expiry_template_html, - self.config.email_expiry_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) - # Check the renewal emails to send and send them every 30min. def send_emails(): # run as a background process to make sure that the database transactions diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index a2d7959abe..8476256a59 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -17,7 +17,6 @@ import logging import twisted import twisted.internet.error -from twisted.internet import defer from twisted.web import server, static from twisted.web.resource import Resource @@ -35,14 +34,13 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC --------------------------------------------------------------------------------""" -class AcmeHandler(object): +class AcmeHandler: def __init__(self, hs): self.hs = hs self.reactor = hs.get_reactor() self._acme_domain = hs.config.acme_domain - @defer.inlineCallbacks - def start_listening(self): + async def start_listening(self): from synapse.handlers import acme_issuing_service # Configure logging for txacme, if you need to debug @@ -82,18 +80,17 @@ class AcmeHandler(object): self._issuer._registered = False try: - yield self._issuer._ensure_registered() + await self._issuer._ensure_registered() except Exception: logger.error(ACME_REGISTER_FAIL_ERROR) raise - @defer.inlineCallbacks - def provision_certificate(self): + async def provision_certificate(self): logger.warning("Reprovisioning %s", self._acme_domain) try: - yield self._issuer.issue_cert(self._acme_domain) + await self._issuer.issue_cert(self._acme_domain) except Exception: logger.exception("Fail!") raise diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py index e1d4224e74..69650ff221 100644 --- a/synapse/handlers/acme_issuing_service.py +++ b/synapse/handlers/acme_issuing_service.py @@ -78,7 +78,7 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou @attr.s @implementer(ICertificateStore) -class ErsatzStore(object): +class ErsatzStore: """ A store that only stores in memory. """ diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index f3c0aeceb6..918d0e037c 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -72,7 +72,7 @@ class AdminHandler(BaseHandler): writer (ExfiltrationWriter) Returns: - defer.Deferred: Resolves when all data for a user has been written. + Resolves when all data for a user has been written. The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in @@ -197,7 +197,7 @@ class AdminHandler(BaseHandler): return writer.finished() -class ExfiltrationWriter(object): +class ExfiltrationWriter: """Interface used to specify how to write exported data. """ diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index fe62f78e67..9d4e87dad6 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -15,8 +15,6 @@ import logging -from six import itervalues - from prometheus_client import Counter from twisted.internet import defer @@ -29,7 +27,6 @@ from synapse.metrics import ( event_processing_loop_room_count, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util import log_failure from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -37,7 +34,7 @@ logger = logging.getLogger(__name__) events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") -class ApplicationServicesHandler(object): +class ApplicationServicesHandler: def __init__(self, hs): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id @@ -50,8 +47,7 @@ class ApplicationServicesHandler(object): self.current_max = 0 self.is_processing = False - @defer.inlineCallbacks - def notify_interested_services(self, current_id): + async def notify_interested_services(self, current_id): """Notifies (pushes) all application services interested in this event. Pushing is done asynchronously, so this method won't block for any @@ -76,7 +72,7 @@ class ApplicationServicesHandler(object): ( upper_bound, events, - ) = yield self.store.get_new_events_for_appservice( + ) = await self.store.get_new_events_for_appservice( self.current_max, limit ) @@ -87,10 +83,9 @@ class ApplicationServicesHandler(object): for event in events: events_by_room.setdefault(event.room_id, []).append(event) - @defer.inlineCallbacks - def handle_event(event): + async def handle_event(event): # Gather interested services - services = yield self._get_services_for_event(event) + services = await self._get_services_for_event(event) if len(services) == 0: return # no services need notifying @@ -98,16 +93,17 @@ class ApplicationServicesHandler(object): # query API for all services which match that user regex. # This needs to block as these user queries need to be # made BEFORE pushing the event. - yield self._check_user_exists(event.sender) + await self._check_user_exists(event.sender) if event.type == EventTypes.Member: - yield self._check_user_exists(event.state_key) + await self._check_user_exists(event.state_key) if not self.started_scheduler: - def start_scheduler(): - return self.scheduler.start().addErrback( - log_failure, "Application Services Failure" - ) + async def start_scheduler(): + try: + return await self.scheduler.start() + except Exception: + logger.error("Application Services Failure") run_as_background_process("as_scheduler", start_scheduler) self.started_scheduler = True @@ -116,25 +112,30 @@ class ApplicationServicesHandler(object): for service in services: self.scheduler.submit_event_for_as(service, event) - @defer.inlineCallbacks - def handle_room_events(events): + now = self.clock.time_msec() + ts = await self.store.get_received_ts(event.event_id) + synapse.metrics.event_processing_lag_by_event.labels( + "appservice_sender" + ).observe((now - ts) / 1000) + + async def handle_room_events(events): for event in events: - yield handle_event(event) + await handle_event(event) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) + for evs in events_by_room.values() ], consumeErrors=True, ) ) - yield self.store.set_appservice_last_pos(upper_bound) + await self.store.set_appservice_last_pos(upper_bound) now = self.clock.time_msec() - ts = yield self.store.get_received_ts(events[-1].event_id) + ts = await self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_positions.labels( "appservice_sender" @@ -157,8 +158,7 @@ class ApplicationServicesHandler(object): finally: self.is_processing = False - @defer.inlineCallbacks - def query_user_exists(self, user_id): + async def query_user_exists(self, user_id): """Check if any application service knows this user_id exists. Args: @@ -166,15 +166,14 @@ class ApplicationServicesHandler(object): Returns: True if this user exists on at least one application service. """ - user_query_services = yield self._get_services_for_user(user_id=user_id) + user_query_services = self._get_services_for_user(user_id=user_id) for user_service in user_query_services: - is_known_user = yield self.appservice_api.query_user(user_service, user_id) + is_known_user = await self.appservice_api.query_user(user_service, user_id) if is_known_user: return True return False - @defer.inlineCallbacks - def query_room_alias_exists(self, room_alias): + async def query_room_alias_exists(self, room_alias): """Check if an application service knows this room alias exists. Args: @@ -189,19 +188,18 @@ class ApplicationServicesHandler(object): s for s in services if (s.is_interested_in_alias(room_alias_str)) ] for alias_service in alias_query_services: - is_known_alias = yield self.appservice_api.query_alias( + is_known_alias = await self.appservice_api.query_alias( alias_service, room_alias_str ) if is_known_alias: # the alias exists now so don't query more ASes. - result = yield self.store.get_association_from_room_alias(room_alias) + result = await self.store.get_association_from_room_alias(room_alias) return result - @defer.inlineCallbacks - def query_3pe(self, kind, protocol, fields): - services = yield self._get_services_for_3pn(protocol) + async def query_3pe(self, kind, protocol, fields): + services = self._get_services_for_3pn(protocol) - results = yield make_deferred_yieldable( + results = await make_deferred_yieldable( defer.DeferredList( [ run_in_background( @@ -220,8 +218,7 @@ class ApplicationServicesHandler(object): return ret - @defer.inlineCallbacks - def get_3pe_protocols(self, only_protocol=None): + async def get_3pe_protocols(self, only_protocol=None): services = self.store.get_app_services() protocols = {} @@ -234,7 +231,7 @@ class ApplicationServicesHandler(object): if p not in protocols: protocols[p] = [] - info = yield self.appservice_api.get_3pe_protocol(s, p) + info = await self.appservice_api.get_3pe_protocol(s, p) if info is not None: protocols[p].append(info) @@ -259,8 +256,7 @@ class ApplicationServicesHandler(object): return protocols - @defer.inlineCallbacks - def _get_services_for_event(self, event): + async def _get_services_for_event(self, event): """Retrieve a list of application services interested in this event. Args: @@ -276,7 +272,7 @@ class ApplicationServicesHandler(object): # inside of a list comprehension anymore. interested_list = [] for s in services: - if (yield s.is_interested(event, self.store)): + if await s.is_interested(event, self.store): interested_list.append(s) return interested_list @@ -284,21 +280,20 @@ class ApplicationServicesHandler(object): def _get_services_for_user(self, user_id): services = self.store.get_app_services() interested_list = [s for s in services if (s.is_interested_in_user(user_id))] - return defer.succeed(interested_list) + return interested_list def _get_services_for_3pn(self, protocol): services = self.store.get_app_services() interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] - return defer.succeed(interested_list) + return interested_list - @defer.inlineCallbacks - def _is_unknown_user(self, user_id): + async def _is_unknown_user(self, user_id): if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. return False - user_info = yield self.store.get_user_by_id(user_id) + user_info = await self.store.get_user_by_id(user_id) if user_info: return False @@ -307,10 +302,9 @@ class ApplicationServicesHandler(object): service_list = [s for s in services if s.sender == user_id] return len(service_list) == 0 - @defer.inlineCallbacks - def _check_user_exists(self, user_id): - unknown_user = yield self._is_unknown_user(user_id) + async def _check_user_exists(self, user_id): + unknown_user = await self._is_unknown_user(user_id) if unknown_user: - exists = yield self.query_user_exists(user_id) + exists = await self.query_user_exists(user_id) return exists return True diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 119678e67b..90189869cc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect import logging import time import unicodedata @@ -24,7 +24,6 @@ import attr import bcrypt # type: ignore[import] import pymacaroons -import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import ( AuthError, @@ -38,19 +37,106 @@ from synapse.api.errors import ( from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.http.server import finish_request +from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi -from synapse.push.mailer import load_jinja2_templates -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID +from synapse.util import stringutils as stringutils +from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.threepids import canonicalise_email from ._base import BaseHandler logger = logging.getLogger(__name__) +def convert_client_dict_legacy_fields_to_identifier( + submission: JsonDict, +) -> Dict[str, str]: + """ + Convert a legacy-formatted login submission to an identifier dict. + + Legacy login submissions (used in both login and user-interactive authentication) + provide user-identifying information at the top-level instead. + + These are now deprecated and replaced with identifiers: + https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types + + Args: + submission: The client dict to convert + + Returns: + The matching identifier dict + + Raises: + SynapseError: If the format of the client dict is invalid + """ + identifier = submission.get("identifier", {}) + + # Generate an m.id.user identifier if "user" parameter is present + user = submission.get("user") + if user: + identifier = {"type": "m.id.user", "user": user} + + # Generate an m.id.thirdparty identifier if "medium" and "address" parameters are present + medium = submission.get("medium") + address = submission.get("address") + if medium and address: + identifier = { + "type": "m.id.thirdparty", + "medium": medium, + "address": address, + } + + # We've converted valid, legacy login submissions to an identifier. If the + # submission still doesn't have an identifier, it's invalid + if not identifier: + raise SynapseError(400, "Invalid login submission", Codes.INVALID_PARAM) + + # Ensure the identifier has a type + if "type" not in identifier: + raise SynapseError( + 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM, + ) + + return identifier + + +def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: + """ + Convert a phone login identifier type to a generic threepid identifier. + + Args: + identifier: Login identifier dict of type 'm.id.phone' + + Returns: + An equivalent m.id.thirdparty identifier dict + """ + if "country" not in identifier or ( + # The specification requires a "phone" field, while Synapse used to require a "number" + # field. Accept both for backwards compatibility. + "phone" not in identifier + and "number" not in identifier + ): + raise SynapseError( + 400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM + ) + + # Accept both "phone" and "number" as valid keys in m.id.phone + phone_number = identifier.get("phone", identifier["number"]) + + # Convert user-provided phone number to a consistent representation + msisdn = phone_number_to_msisdn(identifier["country"], phone_number) + + return { + "type": "m.id.thirdparty", + "medium": "msisdn", + "address": msisdn, + } + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -131,18 +217,17 @@ class AuthHandler(BaseHandler): # after the SSO completes and before redirecting them back to their client. # It notifies the user they are about to give access to their matrix account # to the client. - self._sso_redirect_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_redirect_confirm.html"], - )[0] + self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template + # The following template is shown during user interactive authentication # in the fallback auth scenario. It notifies the user that they are # authenticating for an operation to occur on their account. - self._sso_auth_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_auth_confirm.html"], - )[0] + self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template + # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. self._sso_auth_success_template = hs.config.sso_auth_success_template + # The following template is shown during the SSO authentication process if # the account is deactivated. self._sso_account_deactivated_template = ( @@ -161,7 +246,7 @@ class AuthHandler(BaseHandler): request_body: Dict[str, Any], clientip: str, description: str, - ) -> dict: + ) -> Tuple[dict, str]: """ Checks that the user is who they claim to be, via a UI auth. @@ -182,9 +267,14 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - The parameters for this request (which may + A tuple of (params, session_id). + + 'params' contains the parameters for this request (which may have been given only in a previous call). + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call + Raises: InteractiveAuthIncompleteError if the client has not yet completed any of the permitted login flows @@ -206,7 +296,7 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = await self.check_auth( + result, params, session_id = await self.check_ui_auth( flows, request, request_body, clientip, description ) except LoginError: @@ -229,7 +319,7 @@ class AuthHandler(BaseHandler): if user_id != requester.user.to_string(): raise AuthError(403, "Invalid auth") - return params + return params, session_id def get_enabled_auth_types(self): """Return the enabled user-interactive authentication types @@ -239,7 +329,7 @@ class AuthHandler(BaseHandler): """ return self.checkers.keys() - async def check_auth( + async def check_ui_auth( self, flows: List[List[str]], request: SynapseRequest, @@ -297,7 +387,7 @@ class AuthHandler(BaseHandler): # Convert the URI and method to strings. uri = request.uri.decode("utf-8") - method = request.uri.decode("utf-8") + method = request.method.decode("utf-8") # If there's no session ID, create a new session. if not sid: @@ -360,9 +450,17 @@ class AuthHandler(BaseHandler): # authentication flow. await self.store.set_ui_auth_clientdict(sid, clientdict) + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + + await self.store.add_user_agent_ip_to_ui_auth_session( + session.session_id, user_agent, clientip + ) + if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session.session_id) + session.session_id, self._auth_dict_for_flows(flows, session.session_id) ) # check auth type currently being presented @@ -409,7 +507,7 @@ class AuthHandler(BaseHandler): ret = self._auth_dict_for_flows(flows, session.session_id) ret["completed"] = list(creds) ret.update(errordict) - raise InteractiveAuthIncompleteError(ret) + raise InteractiveAuthIncompleteError(session.session_id, ret) async def add_oob_auth( self, stagetype: str, authdict: Dict[str, Any], clientip: str @@ -863,11 +961,15 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: if hasattr(provider, "on_logged_out"): - await provider.on_logged_out( + # This might return an awaitable, if it does block the log out + # until it completes. + result = provider.on_logged_out( user_id=str(user_info["user"]), device_id=user_info["device_id"], access_token=access_token, ) + if inspect.isawaitable(result): + await result # delete pushers associated with this access token if user_info["token_id"] is not None: @@ -928,7 +1030,7 @@ class AuthHandler(BaseHandler): # for the presence of an email address during password reset was # case sensitive). if medium == "email": - address = address.lower() + address = canonicalise_email(address) await self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() @@ -956,7 +1058,7 @@ class AuthHandler(BaseHandler): # 'Canonicalise' email addresses as per above if medium == "email": - address = address.lower() + address = canonicalise_email(address) identity_handler = self.hs.get_handlers().identity_handler result = await identity_handler.try_unbind_threepid( @@ -1055,13 +1157,8 @@ class AuthHandler(BaseHandler): ) # Render the HTML and return. - html_bytes = self._sso_auth_success_template.encode("utf-8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) + html = self._sso_auth_success_template + respond_with_html(request, 200, html) async def complete_sso_login( self, @@ -1081,13 +1178,7 @@ class AuthHandler(BaseHandler): # flow. deactivated = await self.store.get_user_deactivated_status(registered_user_id) if deactivated: - html_bytes = self._sso_account_deactivated_template.encode("utf-8") - - request.setResponseCode(403) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - request.write(html_bytes) - finish_request(request) + respond_with_html(request, 403, self._sso_account_deactivated_template) return self._complete_sso_login(registered_user_id, request, client_redirect_url) @@ -1128,17 +1219,12 @@ class AuthHandler(BaseHandler): # URL we redirect users to. redirect_url_no_params = client_redirect_url.split("?")[0] - html_bytes = self._sso_redirect_confirm_template.render( + html = self._sso_redirect_confirm_template.render( display_url=redirect_url_no_params, redirect_url=redirect_url, server_name=self._server_name, - ).encode("utf-8") - - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - request.write(html_bytes) - finish_request(request) + ) + respond_with_html(request, 200, html) @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): @@ -1150,7 +1236,7 @@ class AuthHandler(BaseHandler): @attr.s -class MacaroonGenerator(object): +class MacaroonGenerator: hs = attr.ib() diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 64aaa1335c..a4cc4b9a5a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -12,12 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -import xml.etree.ElementTree as ET +import urllib from typing import Dict, Optional, Tuple - -from six.moves import urllib +from xml.etree import ElementTree as ET from twisted.web.client import PartialDownloadError @@ -37,6 +35,7 @@ class CasHandler: """ def __init__(self, hs): + self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -106,7 +105,7 @@ class CasHandler: return user, displayname def _parse_cas_response( - self, cas_response_body: str + self, cas_response_body: bytes ) -> Tuple[str, Dict[str, Optional[str]]]: """ Retrieve the user and other parameters from the CAS response. @@ -212,8 +211,16 @@ class CasHandler: else: if not registered_user_id: + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders( + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name + localpart=localpart, + default_display_name=user_display_name, + user_agent_ips=(user_agent, ip_address), ) await self._auth_handler.complete_sso_login( diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 2afb390a92..25169157c1 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process @@ -29,6 +30,7 @@ class DeactivateAccountHandler(BaseHandler): def __init__(self, hs): super(DeactivateAccountHandler, self).__init__(hs) + self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() @@ -40,23 +42,25 @@ class DeactivateAccountHandler(BaseHandler): # Start the user parter loop so it can resume parting users from rooms where # it left off (if it has work left to do). - hs.get_reactor().callWhenRunning(self._start_user_parting) + if hs.config.worker_app is None: + hs.get_reactor().callWhenRunning(self._start_user_parting) self._account_validity_enabled = hs.config.account_validity.enabled - async def deactivate_account(self, user_id, erase_data, id_server=None): + async def deactivate_account( + self, user_id: str, erase_data: bool, id_server: Optional[str] = None + ) -> bool: """Deactivate a user's account Args: - user_id (str): ID of user to be deactivated - erase_data (bool): whether to GDPR-erase the user's data - id_server (str|None): Use the given identity server when unbinding + user_id: ID of user to be deactivated + erase_data: whether to GDPR-erase the user's data + id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). Returns: - Deferred[bool]: True if identity server supports removing - threepids, otherwise False. + True if identity server supports removing threepids, otherwise False. """ # FIXME: Theoretically there is a race here wherein user resets # password using threepid. @@ -133,11 +137,11 @@ class DeactivateAccountHandler(BaseHandler): return identity_server_supports_unbinding - async def _reject_pending_invites_for_user(self, user_id): + async def _reject_pending_invites_for_user(self, user_id: str): """Reject pending invites addressed to a given user ID. Args: - user_id (str): The user ID to reject pending invites for. + user_id: The user ID to reject pending invites for. """ user = UserID.from_string(user_id) pending_invites = await self.store.get_invited_rooms_for_local_user(user_id) @@ -165,22 +169,16 @@ class DeactivateAccountHandler(BaseHandler): room.room_id, ) - def _start_user_parting(self): + def _start_user_parting(self) -> None: """ Start the process that goes through the table of users pending deactivation, if it isn't already running. - - Returns: - None """ if not self._user_parter_running: run_as_background_process("user_parter_loop", self._user_parter_loop) - async def _user_parter_loop(self): + async def _user_parter_loop(self) -> None: """Loop that parts deactivated users from rooms - - Returns: - None """ self._user_parter_running = True logger.info("Starting user parter") @@ -197,11 +195,8 @@ class DeactivateAccountHandler(BaseHandler): finally: self._user_parter_running = False - async def _part_user(self, user_id): + async def _part_user(self, user_id: str) -> None: """Causes the given user_id to leave all the rooms they're joined to - - Returns: - None """ user = UserID.from_string(user_id) @@ -223,3 +218,31 @@ class DeactivateAccountHandler(BaseHandler): user_id, room_id, ) + + async def activate_account(self, user_id: str) -> None: + """ + Activate an account that was previously deactivated. + + This marks the user as active and not erased in the database, but does + not attempt to rejoin rooms, re-add threepids, etc. + + If enabled, the user will be re-added to the user directory. + + The user will also need a password hash set to actually login. + + Args: + user_id: ID of user to be re-activated + """ + # Add the user to the directory, if necessary. + user = UserID.from_string(user_id) + if self.hs.config.user_directory_search_all_users: + profile = await self.store.get_profileinfo(user.localpart) + await self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + + # Ensure the user is not marked as erased. + await self.store.mark_user_not_erased(user_id) + + # Mark the user as active. + await self.store.set_user_deactivated_status(user_id, False) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 230d170258..643d71a710 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,11 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Optional - -from six import iteritems, itervalues - -from twisted.internet import defer +from typing import Any, Dict, List, Optional from synapse.api import errors from synapse.api.constants import EventTypes @@ -59,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: """ Retrieve the given user's devices Args: - user_id (str): + user_id: The user ID to query for devices. Returns: - defer.Deferred: list[dict[str, X]]: info on each device + info on each device """ set_tag("user_id", user_id) - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None) devices = list(device_map.values()) for device in devices: @@ -83,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - @defer.inlineCallbacks - def get_device(self, user_id, device_id): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """ Retrieve the given device Args: - user_id (str): - device_id (str): + user_id: The user to get the device from + device_id: The device to fetch. Returns: - defer.Deferred: dict[str, X]: info on the device + info on the device Raises: errors.NotFoundError: if the device was not found """ try: - device = yield self.store.get_device(user_id, device_id) + device = await self.store.get_device(user_id, device_id) except errors.StoreError: raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) set_tag("device", device) @@ -108,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler): return device - @measure_func("device.get_user_ids_changed") @trace - @defer.inlineCallbacks - def get_user_ids_changed(self, user_id, from_token): + @measure_func("device.get_user_ids_changed") + async def get_user_ids_changed(self, user_id, from_token): """Get list of users that have had the devices updated, or have newly joined a room, that `user_id` may be interested in. @@ -122,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler): set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_key = yield self.store.get_room_events_max_id() + now_room_key = await self.store.get_room_events_max_id() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # First we check if any devices have changed for users that we share # rooms with. - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -137,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler): # Always tell the user about their own devices tracked_users.add(user_id) - changed = yield self.store.get_users_whose_devices_changed( + changed = await self.store.get_users_whose_devices_changed( from_token.device_list_key, tracked_users ) # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) - member_events = yield self.store.get_membership_changes_for_user( + member_events = await self.store.get_membership_changes_for_user( user_id, from_token.room_key, now_room_key ) rooms_changed.update(event.room_id for event in member_events) @@ -154,12 +147,12 @@ class DeviceWorkerHandler(BaseHandler): possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = yield self.store.get_current_state_ids(room_id) + current_state_ids = await self.store.get_current_state_ids(room_id) # The user may have left the room # TODO: Check if they actually did or if we were just invited. if room_id not in room_ids: - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue @@ -168,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler): # Fetch the current state at the time. try: - event_ids = yield self.store.get_forward_extremeties_for_room( + event_ids = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering=stream_ordering ) except errors.StoreError: @@ -182,7 +175,7 @@ class DeviceWorkerHandler(BaseHandler): log_kv( {"event": "encountered empty previous state", "room_id": room_id} ) - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue @@ -194,14 +187,14 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) + prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. - for state_dict in itervalues(prev_state_ids): + for state_dict in prev_state_ids.values(): member_event = state_dict.get((EventTypes.Member, user_id), None) if not member_event or member_event != current_member_id: - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue @@ -211,14 +204,14 @@ class DeviceWorkerHandler(BaseHandler): # If there has been any change in membership, include them in the # possibly changed list. We'll check if they are joined below, # and we're not toooo worried about spuriously adding users. - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue # check if this member has changed since any of the extremities # at the stream_ordering, and add them to the list if so. - for state_dict in itervalues(prev_state_ids): + for state_dict in prev_state_ids.values(): prev_event_id = state_dict.get(key, None) if not prev_event_id or prev_event_id != event_id: if state_key != user_id: @@ -240,11 +233,12 @@ class DeviceWorkerHandler(BaseHandler): return result - @defer.inlineCallbacks - def on_federation_query_user_devices(self, user_id): - stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") - self_signing_key = yield self.store.get_e2e_cross_signing_key( + async def on_federation_query_user_devices(self, user_id): + stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query( + user_id + ) + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" ) @@ -273,8 +267,7 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - @defer.inlineCallbacks - def check_device_registered( + async def check_device_registered( self, user_id, device_id, initial_device_display_name=None ): """ @@ -292,13 +285,13 @@ class DeviceHandler(DeviceWorkerHandler): str: device id (generated if none was supplied) """ if device_id is not None: - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id # if the device id is not specified, we'll autogen one, but loop a few @@ -306,33 +299,29 @@ class DeviceHandler(DeviceWorkerHandler): attempts = 0 while attempts < 5: device_id = stringutils.random_string(10).upper() - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @trace - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """ Delete the given device Args: - user_id (str): - device_id (str): - - Returns: - defer.Deferred: + user_id: The user to delete the device from. + device_id: The device to delete. """ try: - yield self.store.delete_device(user_id, device_id) + await self.store.delete_device(user_id, device_id) except errors.StoreError as e: if e.code == 404: # no match @@ -344,49 +333,40 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) + await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) @trace - @defer.inlineCallbacks - def delete_all_devices_for_user(self, user_id, except_device_id=None): + async def delete_all_devices_for_user( + self, user_id: str, except_device_id: Optional[str] = None + ) -> None: """Delete all of the user's devices Args: - user_id (str): - except_device_id (str|None): optional device id which should not - be deleted - - Returns: - defer.Deferred: + user_id: The user to remove all devices from + except_device_id: optional device id which should not be deleted """ - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) device_ids = list(device_map) if except_device_id is not None: device_ids = [d for d in device_ids if d != except_device_id] - yield self.delete_devices(user_id, device_ids) + await self.delete_devices(user_id, device_ids) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """ Delete several devices Args: - user_id (str): - device_ids (List[str]): The list of device IDs to delete - - Returns: - defer.Deferred: + user_id: The user to delete devices from. + device_ids: The list of device IDs to delete """ try: - yield self.store.delete_devices(user_id, device_ids) + await self.store.delete_devices(user_id, device_ids) except errors.StoreError as e: if e.code == 404: # no match @@ -399,28 +379,22 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device( + await self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id ) - yield self.notify_device_update(user_id, device_ids) + await self.notify_device_update(user_id, device_ids) - @defer.inlineCallbacks - def update_device(self, user_id, device_id, content): + async def update_device(self, user_id: str, device_id: str, content: dict) -> None: """ Update the given device Args: - user_id (str): - device_id (str): - content (dict): body of update request - - Returns: - defer.Deferred: + user_id: The user to update devices of. + device_id: The device to update. + content: body of update request """ # Reject a new displayname which is too long. @@ -433,10 +407,10 @@ class DeviceHandler(DeviceWorkerHandler): ) try: - yield self.store.update_device( + await self.store.update_device( user_id, device_id, new_display_name=new_display_name ) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: if e.code == 404: raise errors.NotFoundError() @@ -445,12 +419,15 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") - @defer.inlineCallbacks - def notify_device_update(self, user_id, device_ids): + async def notify_device_update(self, user_id, device_ids): """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. """ - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + if not device_ids: + # No changes to notify about, so this is a no-op. + return + + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -461,20 +438,24 @@ class DeviceHandler(DeviceWorkerHandler): set_tag("target_hosts", hosts) - position = yield self.store.add_device_change_to_streams( + position = await self.store.add_device_change_to_streams( user_id, device_ids, list(hosts) ) + if not position: + # This should only happen if there are no updates, so we bail. + return + for device_id in device_ids: logger.debug( "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - yield self.notifier.on_new_event( + self.notifier.on_new_event( "device_list_key", position, users=[user_id], rooms=room_ids ) @@ -486,29 +467,29 @@ class DeviceHandler(DeviceWorkerHandler): self.federation_sender.send_device_messages(host) log_kv({"message": "sent device update to host", "host": host}) - @defer.inlineCallbacks - def notify_user_signature_update(self, from_user_id, user_ids): + async def notify_user_signature_update( + self, from_user_id: str, user_ids: List[str] + ) -> None: """Notify a user that they have made new signatures of other users. Args: - from_user_id (str): the user who made the signature - user_ids (list[str]): the users IDs that have new signatures + from_user_id: the user who made the signature + user_ids: the users IDs that have new signatures """ - position = yield self.store.add_user_signature_change_to_streams( + position = await self.store.add_user_signature_change_to_streams( from_user_id, user_ids ) self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) - @defer.inlineCallbacks - def user_left_room(self, user, room_id): + async def user_left_room(self, user, room_id): user_id = user.to_string() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We no longer share rooms with this user, so we'll no longer # receive device updates. Mark this in DB. - yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) def _update_device_from_client_ips(device, client_ips): @@ -516,7 +497,7 @@ def _update_device_from_client_ips(device, client_ips): device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) -class DeviceListUpdater(object): +class DeviceListUpdater: "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs, device_handler): @@ -551,8 +532,7 @@ class DeviceListUpdater(object): ) @trace - @defer.inlineCallbacks - def incoming_device_list_update(self, origin, edu_content): + async def incoming_device_list_update(self, origin, edu_content): """Called on incoming device list update from federation. Responsible for parsing the EDU and adding to pending updates list. """ @@ -585,7 +565,7 @@ class DeviceListUpdater(object): ) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -610,14 +590,13 @@ class DeviceListUpdater(object): (device_id, stream_id, prev_ids, edu_content) ) - yield self._handle_device_updates(user_id) + await self._handle_device_updates(user_id) @measure_func("_incoming_device_list_update") - @defer.inlineCallbacks - def _handle_device_updates(self, user_id): + async def _handle_device_updates(self, user_id): "Actually handle pending updates." - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -634,7 +613,7 @@ class DeviceListUpdater(object): # Given a list of updates we check if we need to resync. This # happens if we've missed updates. - resync = yield self._need_to_do_resync(user_id, pending_updates) + resync = await self._need_to_do_resync(user_id, pending_updates) if logger.isEnabledFor(logging.INFO): logger.info( @@ -645,16 +624,16 @@ class DeviceListUpdater(object): ) if resync: - yield self.user_device_resync(user_id) + await self.user_device_resync(user_id) else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: - yield self.store.update_remote_device_list_cache_entry( + await self.store.update_remote_device_list_cache_entry( user_id, device_id, content, stream_id ) - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user_id, [device_id for device_id, _, _, _ in pending_updates] ) @@ -662,14 +641,13 @@ class DeviceListUpdater(object): stream_id for _, stream_id, _, _ in pending_updates ) - @defer.inlineCallbacks - def _need_to_do_resync(self, user_id, updates): + async def _need_to_do_resync(self, user_id, updates): """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ seen_updates = self._seen_updates.get(user_id, set()) - extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) + extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) logger.debug("Current extremity for %r: %r", user_id, extremity) @@ -693,8 +671,8 @@ class DeviceListUpdater(object): return False - @defer.inlineCallbacks - def _maybe_retry_device_resync(self): + @trace + async def _maybe_retry_device_resync(self): """Retry to resync device lists that are out of sync, except if another retry is in progress. """ @@ -706,12 +684,12 @@ class DeviceListUpdater(object): # we don't send too many requests. self._resync_retry_in_progress = True # Get all of the users that need resyncing. - need_resync = yield self.store.get_user_ids_requiring_device_list_resync() + need_resync = await self.store.get_user_ids_requiring_device_list_resync() # Iterate over the set of user IDs. for user_id in need_resync: try: # Try to resync the current user's devices list. - result = yield self.user_device_resync( + result = await self.user_device_resync( user_id=user_id, mark_failed_as_stale=False, ) @@ -735,16 +713,17 @@ class DeviceListUpdater(object): # Allow future calls to retry resyncinc out of sync device lists. self._resync_retry_in_progress = False - @defer.inlineCallbacks - def user_device_resync(self, user_id, mark_failed_as_stale=True): + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[dict]: """Fetches all devices for a user and updates the device cache with them. Args: - user_id (str): The user's id whose device_list will be updated. - mark_failed_as_stale (bool): Whether to mark the user's device list as stale + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale if the attempt to resync failed. Returns: - Deferred[dict]: a dict with device info as under the "devices" in the result of this + A dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid """ @@ -753,12 +732,12 @@ class DeviceListUpdater(object): # Fetch all devices for the user. origin = get_domain_from_id(user_id) try: - result = yield self.federation.query_user_devices(origin, user_id) + result = await self.federation.query_user_devices(origin, user_id) except NotRetryingDestination: if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) return except (RequestSendFailed, HttpResponseException) as e: @@ -769,7 +748,7 @@ class DeviceListUpdater(object): if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list @@ -793,7 +772,7 @@ class DeviceListUpdater(object): if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) return log_kv({"result": result}) @@ -834,25 +813,24 @@ class DeviceListUpdater(object): stream_id, ) - yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) + await self.store.update_remote_device_list_cache(user_id, devices, stream_id) device_ids = [device["device_id"] for device in devices] # Handle cross-signing keys. - cross_signing_device_ids = yield self.process_cross_signing_key_update( + cross_signing_device_ids = await self.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + cross_signing_device_ids - yield self.device_handler.notify_device_update(user_id, device_ids) + await self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given # point. self._seen_updates[user_id] = {stream_id} - defer.returnValue(result) + return result - @defer.inlineCallbacks - def process_cross_signing_key_update( + async def process_cross_signing_key_update( self, user_id: str, master_key: Optional[Dict[str, Any]], @@ -873,14 +851,14 @@ class DeviceListUpdater(object): device_ids = [] if master_key: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) _, verify_key = get_verify_key_from_cross_signing_key(master_key) # verify_key is a VerifyKey from signedjson, which uses # .version to denote the portion of the key ID after the # algorithm and colon, which is the device ID device_ids.append(verify_key.version) if self_signing_key: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 05c4b3eec0..64ef7f63ab 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -16,10 +16,6 @@ import logging from typing import Any, Dict -from canonicaljson import json - -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -29,12 +25,13 @@ from synapse.logging.opentracing import ( start_active_span, ) from synapse.types import UserID, get_domain_from_id +from synapse.util import json_encoder from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) -class DeviceMessageHandler(object): +class DeviceMessageHandler: def __init__(self, hs): """ Args: @@ -51,8 +48,7 @@ class DeviceMessageHandler(object): self._device_list_updater = hs.get_device_handler().device_list_updater - @defer.inlineCallbacks - def on_direct_to_device_edu(self, origin, content): + async def on_direct_to_device_edu(self, origin, content): local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): @@ -82,11 +78,11 @@ class DeviceMessageHandler(object): } local_messages[user_id] = messages_by_device - yield self._check_for_unknown_devices( + await self._check_for_unknown_devices( message_type, sender_user_id, by_device ) - stream_id = yield self.store.add_messages_from_remote_to_device_inbox( + stream_id = await self.store.add_messages_from_remote_to_device_inbox( origin, message_id, local_messages ) @@ -94,14 +90,13 @@ class DeviceMessageHandler(object): "to_device_key", stream_id, users=local_messages.keys() ) - @defer.inlineCallbacks - def _check_for_unknown_devices( + async def _check_for_unknown_devices( self, message_type: str, sender_user_id: str, by_device: Dict[str, Dict[str, Any]], ): - """Checks inbound device messages for unkown remote devices, and if + """Checks inbound device messages for unknown remote devices, and if found marks the remote cache for the user as stale. """ @@ -115,7 +110,7 @@ class DeviceMessageHandler(object): requesting_device_ids.add(device_id) # Check if we are tracking the devices of the remote user. - room_ids = yield self.store.get_rooms_for_user(sender_user_id) + room_ids = await self.store.get_rooms_for_user(sender_user_id) if not room_ids: logger.info( "Received device message from remote device we don't" @@ -127,7 +122,7 @@ class DeviceMessageHandler(object): # If we are tracking check that we know about the sending # devices. - cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id) + cached_devices = await self.store.get_cached_devices_for_user(sender_user_id) unknown_devices = requesting_device_ids - set(cached_devices) if unknown_devices: @@ -136,15 +131,14 @@ class DeviceMessageHandler(object): sender_user_id, unknown_devices, ) - yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id) + await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) # Immediately attempt a resync in the background run_in_background( self._device_list_updater.user_device_resync, sender_user_id ) - @defer.inlineCallbacks - def send_device_message(self, sender_user_id, message_type, messages): + async def send_device_message(self, sender_user_id, message_type, messages): set_tag("number_of_messages", len(messages)) set_tag("sender", sender_user_id) local_messages = {} @@ -179,11 +173,11 @@ class DeviceMessageHandler(object): "sender": sender_user_id, "type": message_type, "message_id": message_id, - "org.matrix.opentracing_context": json.dumps(context), + "org.matrix.opentracing_context": json_encoder.encode(context), } log_kv({"local_messages": local_messages}) - stream_id = yield self.store.add_messages_to_device_inbox( + stream_id = await self.store.add_messages_to_device_inbox( local_messages, remote_edu_contents ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index f2f16b1e43..46826eb784 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -17,14 +17,13 @@ import logging import string from typing import Iterable, List, Optional -from twisted.internet import defer - from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( AuthError, CodeMessageException, Codes, NotFoundError, + ShadowBanError, StoreError, SynapseError, ) @@ -55,8 +54,7 @@ class DirectoryHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() - @defer.inlineCallbacks - def _create_association( + async def _create_association( self, room_alias: RoomAlias, room_id: str, @@ -76,13 +74,13 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) servers = {get_domain_from_id(u) for u in users} if not servers: raise SynapseError(400, "Failed to get server list") - yield self.store.create_room_alias_association( + await self.store.create_room_alias_association( room_alias, room_id, servers, creator=creator ) @@ -93,7 +91,7 @@ class DirectoryHandler(BaseHandler): room_id: str, servers: Optional[List[str]] = None, check_membership: bool = True, - ): + ) -> None: """Attempt to create a new alias Args: @@ -103,9 +101,6 @@ class DirectoryHandler(BaseHandler): servers: Iterable of servers that others servers should try and join via check_membership: Whether to check if the user is in the room before the alias can be set (if the server's config requires it). - - Returns: - Deferred """ user_id = requester.user.to_string() @@ -148,7 +143,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") - can_create = await self.can_modify_alias(room_alias, user_id=user_id) + can_create = self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( 400, @@ -158,7 +153,9 @@ class DirectoryHandler(BaseHandler): await self._create_association(room_alias, room_id, servers, creator=user_id) - async def delete_association(self, requester: Requester, room_alias: RoomAlias): + async def delete_association( + self, requester: Requester, room_alias: RoomAlias + ) -> str: """Remove an alias from the directory (this is only meant for human users; AS users should call @@ -169,7 +166,7 @@ class DirectoryHandler(BaseHandler): room_alias Returns: - Deferred[unicode]: room id that the alias used to point to + room id that the alias used to point to Raises: NotFoundError: if the alias doesn't exist @@ -191,7 +188,7 @@ class DirectoryHandler(BaseHandler): if not can_delete: raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = await self.can_modify_alias(room_alias, user_id=user_id) + can_delete = self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( 400, @@ -203,13 +200,14 @@ class DirectoryHandler(BaseHandler): try: await self._update_canonical_alias(requester, user_id, room_id, room_alias) + except ShadowBanError as e: + logger.info("Failed to update alias events due to shadow-ban: %s", e) except AuthError as e: logger.info("Failed to update alias events: %s", e) return room_id - @defer.inlineCallbacks - def delete_appservice_association( + async def delete_appservice_association( self, service: ApplicationService, room_alias: RoomAlias ): if not service.is_interested_in_alias(room_alias.to_string()): @@ -218,29 +216,27 @@ class DirectoryHandler(BaseHandler): "This application service has not reserved this kind of alias", errcode=Codes.EXCLUSIVE, ) - yield self._delete_association(room_alias) + await self._delete_association(room_alias) - @defer.inlineCallbacks - def _delete_association(self, room_alias: RoomAlias): + async def _delete_association(self, room_alias: RoomAlias): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") - room_id = yield self.store.delete_room_alias(room_alias) + room_id = await self.store.delete_room_alias(room_alias) return room_id - @defer.inlineCallbacks - def get_association(self, room_alias: RoomAlias): + async def get_association(self, room_alias: RoomAlias): room_id = None if self.hs.is_mine(room_alias): - result = yield self.get_association_from_room_alias(room_alias) + result = await self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id servers = result.servers else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=room_alias.domain, query_type="directory", args={"room_alias": room_alias.to_string()}, @@ -265,7 +261,7 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} servers = set(extra_servers) | set(servers) @@ -277,13 +273,12 @@ class DirectoryHandler(BaseHandler): return {"room_id": room_id, "servers": servers} - @defer.inlineCallbacks - def on_directory_query(self, args): + async def on_directory_query(self, args): room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room Alias is not hosted on this homeserver") - result = yield self.get_association_from_room_alias(room_alias) + result = await self.get_association_from_room_alias(room_alias) if result is not None: return {"room_id": result.room_id, "servers": result.servers} @@ -300,6 +295,9 @@ class DirectoryHandler(BaseHandler): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. + + Raises: + ShadowBanError if the requester has been shadow-banned. """ alias_event = await self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" @@ -344,16 +342,15 @@ class DirectoryHandler(BaseHandler): ratelimit=False, ) - @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias: RoomAlias): - result = yield self.store.get_association_from_room_alias(room_alias) + async def get_association_from_room_alias(self, room_alias: RoomAlias): + result = await self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists as_handler = self.appservice_handler - result = yield as_handler.query_room_alias_exists(room_alias) + result = await as_handler.query_room_alias_exists(room_alias) return result - def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None): + def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None) -> bool: # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have @@ -366,12 +363,12 @@ class DirectoryHandler(BaseHandler): for service in interested_services: if user_id == service.sender: # this user IS the app service so they can do whatever they like - return defer.succeed(True) + return True elif service.is_exclusive_alias(alias.to_string()): # another service has an exclusive lock on this alias. - return defer.succeed(False) + return False # either no interested services, or no service with an exclusive lock - return defer.succeed(True) + return True async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): """Determine whether a user can delete an alias. @@ -459,8 +456,7 @@ class DirectoryHandler(BaseHandler): await self.store.set_room_is_public(room_id, making_public) - @defer.inlineCallbacks - def edit_published_appservice_room_list( + async def edit_published_appservice_room_list( self, appservice_id: str, network_id: str, room_id: str, visibility: str ): """Add or remove a room from the appservice/network specific public @@ -475,7 +471,7 @@ class DirectoryHandler(BaseHandler): if visibility not in ["public", "private"]: raise SynapseError(400, "Invalid visibility setting") - yield self.store.set_room_is_public_appservice( + await self.store.set_room_is_public_appservice( room_id, appservice_id, network_id, visibility == "public" ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 774a252619..d629c7c16c 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -16,12 +16,11 @@ # limitations under the License. import logging - -from six import iteritems +from typing import Dict, List, Optional, Tuple import attr -from canonicaljson import encode_canonical_json, json -from signedjson.key import decode_verify_key_bytes +from canonicaljson import encode_canonical_json +from signedjson.key import VerifyKey, decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 @@ -36,7 +35,7 @@ from synapse.types import ( get_domain_from_id, get_verify_key_from_cross_signing_key, ) -from synapse.util import unwrapFirstError +from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -44,7 +43,7 @@ from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) -class E2eKeysHandler(object): +class E2eKeysHandler: def __init__(self, hs): self.store = hs.get_datastore() self.federation = hs.get_federation_client() @@ -79,8 +78,7 @@ class E2eKeysHandler(object): ) @trace - @defer.inlineCallbacks - def query_devices(self, query_body, timeout, from_user_id): + async def query_devices(self, query_body, timeout, from_user_id): """ Handle a device key query from a client { @@ -126,7 +124,7 @@ class E2eKeysHandler(object): failures = {} results = {} if local_query: - local_result = yield self.query_local_devices(local_query) + local_result = await self.query_local_devices(local_query) for user_id, keys in local_result.items(): if user_id in local_query: results[user_id] = keys @@ -135,7 +133,7 @@ class E2eKeysHandler(object): remote_queries_not_in_cache = {} if remote_queries: query_list = [] - for user_id, device_ids in iteritems(remote_queries): + for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) else: @@ -144,10 +142,10 @@ class E2eKeysHandler(object): ( user_ids_not_in_cache, remote_results, - ) = yield self.store.get_user_devices_from_cache(query_list) - for user_id, devices in iteritems(remote_results): + ) = await self.store.get_user_devices_from_cache(query_list) + for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) - for device_id, device in iteritems(devices): + for device_id, device in devices.items(): keys = device.get("keys", None) device_display_name = device.get("device_display_name", None) if keys: @@ -163,14 +161,13 @@ class E2eKeysHandler(object): r[user_id] = remote_queries[user_id] # Get cached cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, from_user_id ) # Now fetch any devices that we don't have in our cache @trace - @defer.inlineCallbacks - def do_remote_query(destination): + async def do_remote_query(destination): """This is called when we are querying the device list of a user on a remote homeserver and their device list is not in the device list cache. If we share a room with this user and we're not querying for @@ -194,7 +191,7 @@ class E2eKeysHandler(object): if device_list: continue - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: continue @@ -203,11 +200,11 @@ class E2eKeysHandler(object): # done an initial sync on the device list so we do it now. try: if self._is_master: - user_devices = yield self.device_handler.device_list_updater.user_device_resync( + user_devices = await self.device_handler.device_list_updater.user_device_resync( user_id ) else: - user_devices = yield self._user_device_resync_client( + user_devices = await self._user_device_resync_client( user_id=user_id ) @@ -229,7 +226,7 @@ class E2eKeysHandler(object): destination_query.pop(user_id) try: - remote_result = yield self.federation.query_client_keys( + remote_result = await self.federation.query_client_keys( destination, {"device_keys": destination_query}, timeout=timeout ) @@ -253,7 +250,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(do_remote_query, destination) @@ -269,8 +266,9 @@ class E2eKeysHandler(object): return ret - @defer.inlineCallbacks - def get_cross_signing_keys_from_cache(self, query, from_user_id): + async def get_cross_signing_keys_from_cache( + self, query, from_user_id + ) -> Dict[str, Dict[str, dict]]: """Get cross-signing keys for users from the database Args: @@ -282,8 +280,7 @@ class E2eKeysHandler(object): can see. Returns: - defer.Deferred[dict[str, dict[str, dict]]]: map from - (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key + A map from (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key """ master_keys = {} self_signing_keys = {} @@ -291,7 +288,7 @@ class E2eKeysHandler(object): user_ids = list(query) - keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) for user_id, user_info in keys.items(): if user_info is None: @@ -317,17 +314,17 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def query_local_devices(self, query): + async def query_local_devices( + self, query: Dict[str, Optional[List[str]]] + ) -> Dict[str, Dict[str, dict]]: """Get E2E device keys for local users Args: - query (dict[string, list[string]|None): map from user_id to a list + query: map from user_id to a list of devices to query (None for all devices) Returns: - defer.Deferred: (resolves to dict[string, dict[string, dict]]): - map from user_id -> device_id -> device details + A map from user_id -> device_id -> device details """ set_tag("local_query", query) local_query = [] @@ -356,7 +353,7 @@ class E2eKeysHandler(object): # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = yield self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys_for_cs_api(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -366,16 +363,15 @@ class E2eKeysHandler(object): log_kv(results) return result_dict - @defer.inlineCallbacks - def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys(self, query_body): """ Handle a device key query from a federated server """ device_keys_query = query_body.get("device_keys", {}) - res = yield self.query_local_devices(device_keys_query) + res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} # add in the cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, None ) @@ -384,8 +380,7 @@ class E2eKeysHandler(object): return ret @trace - @defer.inlineCallbacks - def claim_one_time_keys(self, query, timeout): + async def claim_one_time_keys(self, query, timeout): local_query = [] remote_queries = {} @@ -401,7 +396,7 @@ class E2eKeysHandler(object): set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) - results = yield self.store.claim_e2e_one_time_keys(local_query) + results = await self.store.claim_e2e_one_time_keys(local_query) json_result = {} failures = {} @@ -409,16 +404,15 @@ class E2eKeysHandler(object): for device_id, keys in device_keys.items(): for key_id, json_bytes in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) + key_id: json_decoder.decode(json_bytes) } @trace - @defer.inlineCallbacks - def claim_client_keys(destination): + async def claim_client_keys(destination): set_tag("destination", destination) device_keys = remote_queries[destination] try: - remote_result = yield self.federation.claim_client_keys( + remote_result = await self.federation.claim_client_keys( destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): @@ -431,7 +425,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(claim_client_keys, destination) @@ -446,9 +440,9 @@ class E2eKeysHandler(object): ",".join( ( "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) + for user_id, user_keys in json_result.items() + for device_id, device_keys in user_keys.items() + for key_id, _ in device_keys.items() ) ), ) @@ -456,9 +450,8 @@ class E2eKeysHandler(object): log_kv({"one_time_keys": json_result, "failures": failures}) return {"one_time_keys": json_result, "failures": failures} - @defer.inlineCallbacks @tag_args - def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user(self, user_id, device_id, keys): time_now = self.clock.time_msec() @@ -479,12 +472,12 @@ class E2eKeysHandler(object): } ) # TODO: Sign the JSON with the server key - changed = yield self.store.set_e2e_device_keys( + changed = await self.store.set_e2e_device_keys( user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed - yield self.device_handler.notify_device_update(user_id, [device_id]) + await self.device_handler.notify_device_update(user_id, [device_id]) else: log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) one_time_keys = keys.get("one_time_keys", None) @@ -496,7 +489,7 @@ class E2eKeysHandler(object): "device_id": device_id, } ) - yield self._upload_one_time_keys_for_user( + await self._upload_one_time_keys_for_user( user_id, device_id, time_now, one_time_keys ) else: @@ -509,15 +502,14 @@ class E2eKeysHandler(object): # old access_token without an associated device_id. Either way, we # need to double-check the device is registered to avoid ending up with # keys without a corresponding device. - yield self.device_handler.check_device_registered(user_id, device_id) + await self.device_handler.check_device_registered(user_id, device_id) - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + result = await self.store.count_e2e_one_time_keys(user_id, device_id) set_tag("one_time_key_counts", result) return {"one_time_key_counts": result} - @defer.inlineCallbacks - def _upload_one_time_keys_for_user( + async def _upload_one_time_keys_for_user( self, user_id, device_id, time_now, one_time_keys ): logger.info( @@ -535,7 +527,7 @@ class E2eKeysHandler(object): key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. - existing_key_map = yield self.store.get_e2e_one_time_keys( + existing_key_map = await self.store.get_e2e_one_time_keys( user_id, device_id, [k_id for _, k_id, _ in key_list] ) @@ -558,10 +550,9 @@ class E2eKeysHandler(object): ) log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) - yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) + await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - @defer.inlineCallbacks - def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user(self, user_id, keys): """Upload signing keys for cross-signing Args: @@ -576,7 +567,7 @@ class E2eKeysHandler(object): _check_cross_signing_key(master_key, user_id, "master") else: - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") # if there is no master key, then we can't do anything, because all the # other cross-signing keys need to be signed by the master key @@ -615,10 +606,10 @@ class E2eKeysHandler(object): # if everything checks out, then store the keys and send notifications deviceids = [] if "master_key" in keys: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) deviceids.append(master_verify_key.version) if "self_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) try: @@ -628,23 +619,22 @@ class E2eKeysHandler(object): except ValueError: raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) if "user_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "user_signing", user_signing_key ) # the signature stream matches the semantics that we want for # user-signing key updates: only the user themselves is notified of # their own user-signing key updates - yield self.device_handler.notify_user_signature_update(user_id, [user_id]) + await self.device_handler.notify_user_signature_update(user_id, [user_id]) # master key and self-signing key updates match the semantics of device # list updates: all users who share an encrypted room are notified if len(deviceids): - yield self.device_handler.notify_device_update(user_id, deviceids) + await self.device_handler.notify_device_update(user_id, deviceids) return {} - @defer.inlineCallbacks - def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys(self, user_id, signatures): """Upload device signatures for cross-signing Args: @@ -669,13 +659,13 @@ class E2eKeysHandler(object): self_signatures = signatures.get(user_id, {}) other_signatures = {k: v for k, v in signatures.items() if k != user_id} - self_signature_list, self_failures = yield self._process_self_signatures( + self_signature_list, self_failures = await self._process_self_signatures( user_id, self_signatures ) signature_list.extend(self_signature_list) failures.update(self_failures) - other_signature_list, other_failures = yield self._process_other_signatures( + other_signature_list, other_failures = await self._process_other_signatures( user_id, other_signatures ) signature_list.extend(other_signature_list) @@ -683,21 +673,20 @@ class E2eKeysHandler(object): # store the signature, and send the appropriate notifications for sync logger.debug("upload signature failures: %r", failures) - yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) + await self.store.store_e2e_cross_signing_signatures(user_id, signature_list) self_device_ids = [item.target_device_id for item in self_signature_list] if self_device_ids: - yield self.device_handler.notify_device_update(user_id, self_device_ids) + await self.device_handler.notify_device_update(user_id, self_device_ids) signed_users = [item.target_user_id for item in other_signature_list] if signed_users: - yield self.device_handler.notify_user_signature_update( + await self.device_handler.notify_user_signature_update( user_id, signed_users ) return {"failures": failures} - @defer.inlineCallbacks - def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures(self, user_id, signatures): """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -730,7 +719,7 @@ class E2eKeysHandler(object): _, self_signing_key_id, self_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing") # get our master key, since we may have received a signature of it. # We need to fetch it here so that we know what its key ID is, so @@ -740,12 +729,12 @@ class E2eKeysHandler(object): master_key, _, master_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "master") # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = yield self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") @@ -855,8 +844,7 @@ class E2eKeysHandler(object): return master_key_signature_list - @defer.inlineCallbacks - def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures(self, user_id, signatures): """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. @@ -884,7 +872,7 @@ class E2eKeysHandler(object): user_signing_key, user_signing_key_id, user_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing") except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): @@ -907,7 +895,7 @@ class E2eKeysHandler(object): master_key, master_key_id, _, - ) = yield self._get_e2e_cross_signing_verify_key( + ) = await self._get_e2e_cross_signing_verify_key( target_user, "master", user_id ) @@ -960,8 +948,7 @@ class E2eKeysHandler(object): return signature_list, failures - @defer.inlineCallbacks - def _get_e2e_cross_signing_verify_key( + async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None ): """Fetch locally or remotely query for a cross-signing public key. @@ -985,7 +972,7 @@ class E2eKeysHandler(object): SynapseError: if `user_id` is invalid """ user = UserID.from_string(user_id) - key = yield self.store.get_e2e_cross_signing_key( + key = await self.store.get_e2e_cross_signing_key( user_id, key_type, from_user_id ) @@ -1011,17 +998,16 @@ class E2eKeysHandler(object): key, key_id, verify_key, - ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type) + ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type) if key is None: raise NotFoundError("No %s key found for %s" % (key_type, user_id)) return key, key_id, verify_key - @defer.inlineCallbacks - def _retrieve_cross_signing_keys_for_remote_user( + async def _retrieve_cross_signing_keys_for_remote_user( self, user: UserID, desired_key_type: str, - ): + ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database Only the key specified by `key_type` will be returned, while all retrieved keys @@ -1032,12 +1018,11 @@ class E2eKeysHandler(object): desired_key_type: The type of key to receive. One of "master", "self_signing" Returns: - Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple - of the retrieved key content, the key's ID and the matching VerifyKey. + A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ try: - remote_result = yield self.federation.query_user_devices( + remote_result = await self.federation.query_user_devices( user.domain, user.to_string() ) except Exception as e: @@ -1103,14 +1088,14 @@ class E2eKeysHandler(object): desired_key_id = key_id # At the same time, store this key in the db for subsequent queries - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user.to_string(), key_type, key_content ) # Notify clients that new devices for this user have been discovered if retrieved_device_ids: # XXX is this necessary? - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user.to_string(), retrieved_device_ids ) @@ -1201,7 +1186,7 @@ def _exception_to_failure(e): def _one_time_keys_match(old_key_json, new_key): - old_key = json.loads(old_key_json) + old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly if not isinstance(old_key, dict) or not isinstance(new_key, dict): @@ -1227,7 +1212,7 @@ class SignatureListItem: signature = attr.ib() -class SigningKeyEduUpdater(object): +class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" def __init__(self, hs, e2e_keys_handler): @@ -1252,8 +1237,7 @@ class SigningKeyEduUpdater(object): iterable=True, ) - @defer.inlineCallbacks - def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update(self, origin, edu_content): """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. @@ -1270,7 +1254,7 @@ class SigningKeyEduUpdater(object): logger.warning("Got signing key update edu for %r from %r", user_id, origin) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -1280,10 +1264,9 @@ class SigningKeyEduUpdater(object): (master_key, self_signing_key) ) - yield self._handle_signing_key_updates(user_id) + await self._handle_signing_key_updates(user_id) - @defer.inlineCallbacks - def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id): """Actually handle pending updates. Args: @@ -1293,7 +1276,7 @@ class SigningKeyEduUpdater(object): device_handler = self.e2e_keys_handler.device_handler device_list_updater = device_handler.device_list_updater - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -1304,9 +1287,9 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = yield device_list_updater.process_cross_signing_key_update( + new_device_ids = await device_list_updater.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + new_device_ids - yield device_handler.notify_device_update(user_id, device_ids) + await device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 9abaf13b8f..f01b090772 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -16,10 +16,6 @@ import logging -from six import iteritems - -from twisted.internet import defer - from synapse.api.errors import ( Codes, NotFoundError, @@ -33,7 +29,7 @@ from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) -class E2eRoomKeysHandler(object): +class E2eRoomKeysHandler: """ Implements an optional realtime backup mechanism for encrypted E2E megolm room keys. This gives a way for users to store and recover their megolm keys if they lose all @@ -52,8 +48,7 @@ class E2eRoomKeysHandler(object): self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - @defer.inlineCallbacks - def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. @@ -73,17 +68,17 @@ class E2eRoomKeysHandler(object): # we deliberately take the lock to get keys so that changing the version # works atomically - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - yield self.store.get_e2e_room_keys_version_info(user_id, version) + await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - results = yield self.store.get_e2e_room_keys( + results = await self.store.get_e2e_room_keys( user_id, version, room_id, session_id ) @@ -91,8 +86,7 @@ class E2eRoomKeysHandler(object): return results @trace - @defer.inlineCallbacks - def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. @@ -111,10 +105,10 @@ class E2eRoomKeysHandler(object): """ # lock for consistency with uploading - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -123,19 +117,18 @@ class E2eRoomKeysHandler(object): else: raise - yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) version_etag = version_info["etag"] + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @trace - @defer.inlineCallbacks - def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys(self, user_id, version, room_keys): """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). @@ -171,11 +164,11 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # XXX: perhaps we should use a finer grained lock here? - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # Check that the version we're trying to upload is the current version try: - version_info = yield self.store.get_e2e_room_keys_version_info(user_id) + version_info = await self.store.get_e2e_room_keys_version_info(user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -185,7 +178,7 @@ class E2eRoomKeysHandler(object): if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) # if we get this far, the version must exist @@ -200,13 +193,13 @@ class E2eRoomKeysHandler(object): # submitted. Then compare them with the submitted keys. If the # key is new, insert it; if the key should be updated, then update # it; otherwise, drop it. - existing_keys = yield self.store.get_e2e_room_keys_multi( + existing_keys = await self.store.get_e2e_room_keys_multi( user_id, version, room_keys["rooms"] ) to_insert = [] # batch the inserts together changed = False # if anything has changed, we need to update the etag - for room_id, room in iteritems(room_keys["rooms"]): - for session_id, room_key in iteritems(room["sessions"]): + for room_id, room in room_keys["rooms"].items(): + for session_id, room_key in room["sessions"].items(): if not isinstance(room_key["is_verified"], bool): msg = ( "is_verified must be a boolean in keys for session %s in" @@ -229,7 +222,7 @@ class E2eRoomKeysHandler(object): # updates are done one at a time in the DB, so send # updates right away rather than batching them up, # like we do with the inserts - yield self.store.update_e2e_room_key( + await self.store.update_e2e_room_key( user_id, version, room_id, session_id, room_key ) changed = True @@ -248,16 +241,16 @@ class E2eRoomKeysHandler(object): changed = True if len(to_insert): - yield self.store.add_e2e_room_keys(user_id, version, to_insert) + await self.store.add_e2e_room_keys(user_id, version, to_insert) version_etag = version_info["etag"] if changed: version_etag = version_etag + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @staticmethod @@ -293,8 +286,7 @@ class E2eRoomKeysHandler(object): return True @trace - @defer.inlineCallbacks - def create_version(self, user_id, version_info): + async def create_version(self, user_id, version_info): """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. @@ -315,14 +307,13 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # lock everyone out until we've switched version - with (yield self._upload_linearizer.queue(user_id)): - new_version = yield self.store.create_e2e_room_keys_version( + with (await self._upload_linearizer.queue(user_id)): + new_version = await self.store.create_e2e_room_keys_version( user_id, version_info ) return new_version - @defer.inlineCallbacks - def get_version_info(self, user_id, version=None): + async def get_version_info(self, user_id, version=None): """Get the info about a given version of the user's backup Args: @@ -341,21 +332,21 @@ class E2eRoomKeysHandler(object): } """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - res = yield self.store.get_e2e_room_keys_version_info(user_id, version) + res = await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) + res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"]) + res["etag"] = str(res["etag"]) return res @trace - @defer.inlineCallbacks - def delete_version(self, user_id, version=None): + async def delete_version(self, user_id, version=None): """Deletes a given version of the user's e2e_room_keys backup Args: @@ -365,9 +356,9 @@ class E2eRoomKeysHandler(object): NotFoundError: if this backup version doesn't exist """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - yield self.store.delete_e2e_room_keys_version(user_id, version) + await self.store.delete_e2e_room_keys_version(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") @@ -375,8 +366,7 @@ class E2eRoomKeysHandler(object): raise @trace - @defer.inlineCallbacks - def update_version(self, user_id, version, version_info): + async def update_version(self, user_id, version, version_info): """Update the info about a given version of the user's backup Args: @@ -394,9 +384,9 @@ class E2eRoomKeysHandler(object): raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - old_info = yield self.store.get_e2e_room_keys_version_info( + old_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -407,7 +397,7 @@ class E2eRoomKeysHandler(object): if old_info["algorithm"] != version_info["algorithm"]: raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, version_info ) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 71a89f09c7..b05e32f457 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -15,29 +15,30 @@ import logging import random +from typing import TYPE_CHECKING, Iterable, List, Optional from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.utils import log_function -from synapse.types import UserID +from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) class EventStreamHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(EventStreamHandler, self).__init__(hs) - # Count of active streams per user - self._streams_per_user = {} - # Grace timers per user to delay the "stopped" signal - self._stop_timer_per_user = {} - self.distributor = hs.get_distributor() self.distributor.declare("started_user_eventstream") self.distributor.declare("stopped_user_eventstream") @@ -52,18 +53,15 @@ class EventStreamHandler(BaseHandler): @log_function async def get_stream( self, - auth_user_id, - pagin_config, - timeout=0, - as_client_event=True, - affect_presence=True, - only_keys=None, - room_id=None, - is_guest=False, - ): + auth_user_id: str, + pagin_config: PaginationConfig, + timeout: int = 0, + as_client_event: bool = True, + affect_presence: bool = True, + room_id: Optional[str] = None, + is_guest: bool = False, + ) -> JsonDict: """Fetches the events stream for a given user. - - If `only_keys` is not None, events from keys will be sent down. """ if room_id: @@ -93,7 +91,6 @@ class EventStreamHandler(BaseHandler): auth_user, pagin_config, timeout, - only_keys=only_keys, is_guest=is_guest, explicit_room_id=room_id, ) @@ -102,7 +99,7 @@ class EventStreamHandler(BaseHandler): # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. - to_add = [] + to_add = [] # type: List[JsonDict] for event in events: if not isinstance(event, EventBase): continue @@ -114,7 +111,7 @@ class EventStreamHandler(BaseHandler): # Send down presence for everyone in the room. users = await self.state.get_current_users_in_room( event.room_id - ) + ) # type: Iterable[str] else: users = [event.state_key] @@ -148,20 +145,22 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(EventHandler, self).__init__(hs) self.storage = hs.get_storage() - async def get_event(self, user, room_id, event_id): + async def get_event( + self, user: UserID, room_id: Optional[str], event_id: str + ) -> Optional[EventBase]: """Retrieve a single specified event. Args: - user (synapse.types.UserID): The user requesting the event - room_id (str|None): The expected room id. We'll return None if the + user: The user requesting the event + room_id: The expected room id. We'll return None if the event's room does not match. - event_id (str): The event ID to obtain. + event_id: The event ID to obtain. Returns: - dict: An event, or None if there is no event matching this ID. + An event, or None if there is no event matching this ID. Raises: SynapseError if there was a problem retrieving this event, or AuthError if the user does not have the rights to inspect this diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3e60774b33..43f2986f89 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,11 +19,9 @@ import itertools import logging -from typing import Dict, Iterable, List, Optional, Sequence, Tuple - -import six -from six import iteritems, itervalues -from six.moves import http_client, zip +from collections.abc import Container +from http import HTTPStatus +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union import attr from signedjson.key import decode_verify_key_bytes @@ -33,7 +31,12 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RejectedReason +from synapse.api.constants import ( + EventTypes, + Membership, + RejectedReason, + RoomEncryptionAlgorithms, +) from synapse.api.errors import ( AuthError, CodeMessageException, @@ -41,6 +44,7 @@ from synapse.api.errors import ( FederationDeniedError, FederationError, HttpResponseException, + NotFoundError, RequestSendFailed, SynapseError, ) @@ -58,6 +62,7 @@ from synapse.logging.context import ( run_in_background, ) from synapse.logging.utils import log_function +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, @@ -66,8 +71,14 @@ from synapse.replication.http.federation import ( ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.types import ( + JsonDict, + MutableStateMap, + StateMap, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -91,7 +102,7 @@ class _NewEventInfo: event = attr.ib(type=EventBase) state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) + auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) class FederationHandler(BaseHandler): @@ -238,7 +249,7 @@ class FederationHandler(BaseHandler): logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_seen_events(prevs) + seen = await self.store.have_events_in_timeline(prevs) if min_depth is not None and pdu.depth < min_depth: # This is so that we don't notify the user about this @@ -278,7 +289,7 @@ class FederationHandler(BaseHandler): # Update the set of things we've seen after trying to # fetch the missing stuff - seen = await self.store.have_seen_events(prevs) + seen = await self.store.have_events_in_timeline(prevs) if not prevs - seen: logger.info( @@ -374,6 +385,7 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(room_id) state_map = await resolve_events_with_store( + self.clock, room_id, room_version, state_maps, @@ -393,7 +405,7 @@ class FederationHandler(BaseHandler): ) event_map.update(evs) - state = [event_map[e] for e in six.itervalues(state_map)] + state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( "[%s %s] Error attempting to resolve state at missing " @@ -423,16 +435,16 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - seen = await self.store.have_seen_events(prevs) + seen = await self.store.have_events_in_timeline(prevs) if not prevs - seen: return - latest = await self.store.get_latest_event_ids_in_room(room_id) + latest_list = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us - latest = set(latest) + latest = set(latest_list) latest |= seen logger.info( @@ -614,6 +626,11 @@ class FederationHandler(BaseHandler): will be omitted from the result. Likewise, any events which turn out not to be in the given room. + This function *does not* automatically get missing auth events of the + newly fetched events. Callers must include the full auth chain of + of the missing events in the `event_ids` argument, to ensure that any + missing auth events are correctly fetched. + Returns: map from event_id to event """ @@ -739,10 +756,16 @@ class FederationHandler(BaseHandler): # device and recognize the algorithm then we can work out the # exact key to expect. Otherwise check it matches any key we # have for that device. + + current_keys = [] # type: Container[str] + if device: keys = device.get("keys", {}).get("keys", {}) - if event.content.get("algorithm") == "m.megolm.v1.aes-sha2": + if ( + event.content.get("algorithm") + == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2 + ): # For this algorithm we expect a curve25519 key. key_name = "curve25519:%s" % (device_id,) current_keys = [keys.get(key_name)] @@ -752,15 +775,15 @@ class FederationHandler(BaseHandler): current_keys = keys.values() elif device_id: # We don't have any keys for the device ID. - current_keys = [] + pass else: # The event didn't include a device ID, so we just look for # keys across all devices. - current_keys = ( + current_keys = [ key - for device in cached_devices + for device in cached_devices.values() for key in device.get("keys", {}).get("keys", {}).values() - ) + ] # We now check that the sender key matches (one of) the expected # keys. @@ -774,15 +797,25 @@ class FederationHandler(BaseHandler): resync = True if resync: - await self.store.mark_remote_user_device_cache_as_stale(event.sender) + run_as_background_process( + "resync_device_due_to_pdu", self._resync_device, event.sender + ) - # Immediately attempt a resync in the background - if self.config.worker_app: - return run_in_background(self._user_device_resync, event.sender) - else: - return run_in_background( - self._device_list_updater.user_device_resync, event.sender - ) + async def _resync_device(self, sender: str) -> None: + """We have detected that the device list for the given user may be out + of sync, so we try and resync them. + """ + + try: + await self.store.mark_remote_user_device_cache_as_stale(sender) + + # Immediately attempt a resync in the background + if self.config.worker_app: + await self._user_device_resync(user_id=sender) + else: + await self._device_list_updater.user_device_resync(sender) + except Exception: + logger.exception("Failed to resync device for %s", sender) @log_function async def backfill(self, dest, room_id, limit, extremities): @@ -1001,11 +1034,11 @@ class FederationHandler(BaseHandler): """ joined_users = [ (state_key, int(event.depth)) - for (e_type, state_key), event in iteritems(state) + for (e_type, state_key), event in state.items() if e_type == EventTypes.Member and event.membership == Membership.JOIN ] - joined_domains = {} + joined_domains = {} # type: Dict[str, int] for u, d in joined_users: try: dom = get_domain_from_id(u) @@ -1091,16 +1124,16 @@ class FederationHandler(BaseHandler): states = dict(zip(event_ids, [s.state for s in states])) state_map = await self.store.get_events( - [e_id for ids in itervalues(states) for e_id in itervalues(ids)], + [e_id for ids in states.values() for e_id in ids.values()], get_prev_content=False, ) states = { key: { k: state_map[e_id] - for k, e_id in iteritems(state_dict) + for k, e_id in state_dict.items() if e_id in state_map } - for key, state_dict in iteritems(states) + for key, state_dict in states.items() } for e_id, _ in sorted_extremeties_tuple: @@ -1121,12 +1154,16 @@ class FederationHandler(BaseHandler): ): """Fetch the given events from a server, and persist them as outliers. + This function *does not* recursively get missing auth events of the + newly fetched events. Callers must include in the `events` argument + any missing events from the auth chain. + Logs a warning if we can't find the given event. """ room_version = await self.store.get_room_version(room_id) - event_infos = [] + event_map = {} # type: Dict[str, EventBase] async def get_event(event_id: str): with nested_logging_context(event_id): @@ -1140,17 +1177,7 @@ class FederationHandler(BaseHandler): ) return - # recursively fetch the auth events for this event - auth_events = await self._get_events_from_store_or_dest( - destination, room_id, event.auth_event_ids() - ) - auth = {} - for auth_event_id in event.auth_event_ids(): - ae = auth_events.get(auth_event_id) - if ae: - auth[(ae.type, ae.state_key)] = ae - - event_infos.append(_NewEventInfo(event, None, auth)) + event_map[event.event_id] = event except Exception as e: logger.warning( @@ -1162,6 +1189,32 @@ class FederationHandler(BaseHandler): await concurrently_execute(get_event, events, 5) + # Make a map of auth events for each event. We do this after fetching + # all the events as some of the events' auth events will be in the list + # of requested events. + + auth_events = [ + aid + for event in event_map.values() + for aid in event.auth_event_ids() + if aid not in event_map + ] + persisted_events = await self.store.get_events( + auth_events, allow_rejected=True, + ) + + event_infos = [] + for event in event_map.values(): + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + else: + logger.info("Missing auth event %s", auth_event_id) + + event_infos.append(_NewEventInfo(event, None, auth)) + await self._handle_new_events( destination, event_infos, ) @@ -1188,7 +1241,7 @@ class FederationHandler(BaseHandler): ev.event_id, len(ev.prev_event_ids()), ) - raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: logger.warning( @@ -1196,7 +1249,7 @@ class FederationHandler(BaseHandler): ev.event_id, len(ev.auth_event_ids()), ) - raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. @@ -1271,14 +1324,15 @@ class FederationHandler(BaseHandler): try: # Try the host we successfully got a response to /make_join/ # request first. + host_list = list(target_hosts) try: - target_hosts.remove(origin) - target_hosts.insert(0, origin) + host_list.remove(origin) + host_list.insert(0, origin) except ValueError: pass ret = await self.federation_client.send_join( - target_hosts, event, room_version_obj + host_list, event, room_version_obj ) origin = ret["origin"] @@ -1346,7 +1400,7 @@ class FederationHandler(BaseHandler): # it's just a best-effort thing at this point. We do want to do # them roughly in order, though, otherwise we'll end up making # lots of requests for missing prev_events which we do actually - # have. Hence we fire off the deferred, but don't wait for it. + # have. Hence we fire off the background task, but don't wait for it. run_in_background(self._handle_queued_pdus, room_queue) @@ -1392,10 +1446,20 @@ class FederationHandler(BaseHandler): ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - event_content = {"membership": Membership.JOIN} - + # checking the room version will check that we've actually heard of the room + # (and return a 404 otherwise) room_version = await self.store.get_room_version_id(room_id) + # now check that we are *still* in the room + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) + if not is_in_room: + logger.info( + "Got /make_join request for room %s we are no longer in", room_id, + ) + raise NotFoundError("Not an active room on this server") + + event_content = {"membership": Membership.JOIN} + builder = self.event_builder_factory.new( room_version, { @@ -1539,7 +1603,7 @@ class FederationHandler(BaseHandler): # block any attempts to invite the server notices mxid if event.state_key == self._server_notices_mxid: - raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a @@ -1556,7 +1620,7 @@ class FederationHandler(BaseHandler): room_version, event.get_pdu_json(), self.hs.hostname, - self.hs.config.signing_key[0], + self.hs.signing_key, ) ) @@ -1578,13 +1642,14 @@ class FederationHandler(BaseHandler): # Try the host that we succesfully called /make_leave/ on first for # the /send_leave/ request. + host_list = list(target_hosts) try: - target_hosts.remove(origin) - target_hosts.insert(0, origin) + host_list.remove(origin) + host_list.insert(0, origin) except ValueError: pass - await self.federation_client.send_leave(target_hosts, event) + await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) stream_id = await self.persist_events_and_notify([(event, context)]) @@ -1598,7 +1663,7 @@ class FederationHandler(BaseHandler): user_id: str, membership: str, content: JsonDict = {}, - params: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Union[str, Iterable[str]]]] = None, ) -> Tuple[str, EventBase, RoomVersion]: ( origin, @@ -1718,14 +1783,12 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) if state_groups: - _, state = list(iteritems(state_groups)).pop() + _, state = list(state_groups.items()).pop() results = {(e.type, e.state_key): e for e in state} if event.is_state(): @@ -1746,9 +1809,7 @@ class FederationHandler(BaseHandler): async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -1818,8 +1879,8 @@ class FederationHandler(BaseHandler): else: return None - def get_min_depth_for_context(self, context): - return self.store.get_min_depth(context) + async def get_min_depth_for_context(self, context): + return await self.store.get_min_depth(context) async def _handle_new_event( self, origin, event, state=None, auth_events=None, backfilled=False @@ -1828,9 +1889,6 @@ class FederationHandler(BaseHandler): origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) - # reraise does not allow inlineCallbacks to preserve the stacktrace, so we - # hack around with a try/finally instead. - success = False try: if ( not event.internal_metadata.is_outlier() @@ -1844,12 +1902,11 @@ class FederationHandler(BaseHandler): await self.persist_events_and_notify( [(event, context)], backfilled=backfilled ) - success = True - finally: - if not success: - run_in_background( - self.store.remove_push_actions_from_staging, event.event_id - ) + except Exception: + run_in_background( + self.store.remove_push_actions_from_staging, event.event_id + ) + raise return context @@ -2002,18 +2059,18 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - auth_events: Optional[StateMap[EventBase]], + auth_events: Optional[MutableStateMap[EventBase]], backfilled: bool, ) -> EventContext: context = await self.state_handler.compute_event_context(event, old_state=state) if not auth_events: prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_x = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. @@ -2049,76 +2106,71 @@ class FederationHandler(BaseHandler): # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we # "soft-fail" the event. - do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() - if do_soft_fail_check: - extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) - - extrem_ids = set(extrem_ids) - prev_event_ids = set(event.prev_event_ids()) - - if extrem_ids == prev_event_ids: - # If they're the same then the current state is the same as the - # state at the event, so no point rechecking auth for soft fail. - do_soft_fail_check = False - - if do_soft_fail_check: - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - # Calculate the "current state". - if state is not None: - # If we're explicitly given the state then we won't have all the - # prev events, and so we have a gap in the graph. In this case - # we want to be a little careful as we might have been down for - # a while and have an incorrect view of the current state, - # however we still want to do checks as gaps are easy to - # maliciously manufacture. - # - # So we use a "current state" that is actually a state - # resolution across the current forward extremities and the - # given state at the event. This should correctly handle cases - # like bans, especially with state res v2. + if backfilled or event.internal_metadata.is_outlier(): + return - state_sets = await self.state_store.get_state_groups( - event.room_id, extrem_ids - ) - state_sets = list(state_sets.values()) - state_sets.append(state) - current_state_ids = await self.state_handler.resolve_events( - room_version, state_sets, event - ) - current_state_ids = { - k: e.event_id for k, e in iteritems(current_state_ids) - } - else: - current_state_ids = await self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids - ) + extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = set(extrem_ids_list) + prev_event_ids = set(event.prev_event_ids()) - logger.debug( - "Doing soft-fail check for %s: state %s", - event.event_id, - current_state_ids, + if extrem_ids == prev_event_ids: + # If they're the same then the current state is the same as the + # state at the event, so no point rechecking auth for soft fail. + return + + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + # Calculate the "current state". + if state is not None: + # If we're explicitly given the state then we won't have all the + # prev events, and so we have a gap in the graph. In this case + # we want to be a little careful as we might have been down for + # a while and have an incorrect view of the current state, + # however we still want to do checks as gaps are easy to + # maliciously manufacture. + # + # So we use a "current state" that is actually a state + # resolution across the current forward extremities and the + # given state at the event. This should correctly handle cases + # like bans, especially with state res v2. + + state_sets = await self.state_store.get_state_groups( + event.room_id, extrem_ids + ) + state_sets = list(state_sets.values()) + state_sets.append(state) + current_states = await self.state_handler.resolve_events( + room_version, state_sets, event + ) + current_state_ids = { + k: e.event_id for k, e in current_states.items() + } # type: StateMap[str] + else: + current_state_ids = await self.state_handler.get_current_state_ids( + event.room_id, latest_event_ids=extrem_ids ) - # Now check if event pass auth against said current state - auth_types = auth_types_for_event(event) - current_state_ids = [ - e for k, e in iteritems(current_state_ids) if k in auth_types - ] + logger.debug( + "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, + ) - current_auth_events = await self.store.get_events(current_state_ids) - current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() - } + # Now check if event pass auth against said current state + auth_types = auth_types_for_event(event) + current_state_ids_list = [ + e for k, e in current_state_ids.items() if k in auth_types + ] - try: - event_auth.check( - room_version_obj, event, auth_events=current_auth_events - ) - except AuthError as e: - logger.warning("Soft-failing %r because %s", event, e) - event.internal_metadata.soft_failed = True + auth_events_map = await self.store.get_events(current_state_ids_list) + current_auth_events = { + (e.type, e.state_key): e for e in auth_events_map.values() + } + + try: + event_auth.check(room_version_obj, event, auth_events=current_auth_events) + except AuthError as e: + logger.warning("Soft-failing %r because %s", event, e) + event.internal_metadata.soft_failed = True async def on_query_auth( self, origin, event_id, room_id, remote_auth_chain, rejects, missing @@ -2127,9 +2179,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. @@ -2181,7 +2231,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """ @@ -2232,7 +2282,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - auth_events: StateMap[EventBase], + auth_events: MutableStateMap[EventBase], ) -> EventContext: """Helper for do_auth. See there for docs. @@ -2287,10 +2337,10 @@ class FederationHandler(BaseHandler): remote_auth_chain = await self.federation_client.get_event_auth( origin, event.room_id, event.event_id ) - except RequestSendFailed as e: + except RequestSendFailed as e1: # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. - logger.info("Failed to get event auth from remote: %s", e) + logger.info("Failed to get event auth from remote: %s", e1) return context seen_remotes = await self.store.have_seen_events( @@ -2420,18 +2470,18 @@ class FederationHandler(BaseHandler): else: event_key = None state_updates = { - k: a.event_id for k, a in iteritems(auth_events) if k != event_key + k: a.event_id for k, a in auth_events.items() if k != event_key } current_state_ids = await context.get_current_state_ids() - current_state_ids = dict(current_state_ids) + current_state_ids = dict(current_state_ids) # type: ignore current_state_ids.update(state_updates) prev_state_ids = await context.get_prev_state_ids() prev_state_ids = dict(prev_state_ids) - prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) + prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) # create a new state group as a delta from the existing one. prev_group = context.state_group @@ -2768,7 +2818,8 @@ class FederationHandler(BaseHandler): logger.debug("Checking auth on event %r", event.content) - last_exception = None + last_exception = None # type: Optional[Exception] + # for each public key in the 3pid invite event for public_key_object in self.hs.get_auth().get_public_keys(invite_event): try: @@ -2822,6 +2873,12 @@ class FederationHandler(BaseHandler): return except Exception as e: last_exception = e + + if last_exception is None: + # we can only get here if get_public_keys() returned an empty list + # TODO: make this better + raise RuntimeError("no public key in invite event") + raise last_exception async def _check_key_revocation(self, public_key, url): @@ -2937,7 +2994,9 @@ class FederationHandler(BaseHandler): else: user_joined_room(self.distributor, user, room_id) - async def get_room_complexity(self, remote_room_hosts, room_id): + async def get_room_complexity( + self, remote_room_hosts: List[str], room_id: str + ) -> Optional[dict]: """ Fetch the complexity of a remote room over federation. @@ -2946,7 +3005,7 @@ class FederationHandler(BaseHandler): room_id (str): The room ID to ask about. Returns: - Deferred[dict] or Deferred[None]: Dict contains the complexity + Dict contains the complexity metric versions, while None means we could not fetch the complexity. """ diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 7d62899824..ba47e3a242 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -16,8 +16,6 @@ import logging -from six import iteritems - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.types import get_domain_from_id @@ -25,43 +23,36 @@ logger = logging.getLogger(__name__) def _create_rerouter(func_name): - """Returns a function that looks at the group id and calls the function + """Returns an async function that looks at the group id and calls the function on federation or the local group server if the group is local """ - def f(self, group_id, *args, **kwargs): + async def f(self, group_id, *args, **kwargs): if self.is_mine_id(group_id): - return getattr(self.groups_server_handler, func_name)( + return await getattr(self.groups_server_handler, func_name)( group_id, *args, **kwargs ) else: destination = get_domain_from_id(group_id) - d = getattr(self.transport_client, func_name)( - destination, group_id, *args, **kwargs - ) - # Capture errors returned by the remote homeserver and - # re-throw specific errors as SynapseErrors. This is so - # when the remote end responds with things like 403 Not - # In Group, we can communicate that to the client instead - # of a 500. - def http_response_errback(failure): - failure.trap(HttpResponseException) - e = failure.value + try: + return await getattr(self.transport_client, func_name)( + destination, group_id, *args, **kwargs + ) + except HttpResponseException as e: + # Capture errors returned by the remote homeserver and + # re-throw specific errors as SynapseErrors. This is so + # when the remote end responds with things like 403 Not + # In Group, we can communicate that to the client instead + # of a 500. raise e.to_synapse_error() - - def request_failed_errback(failure): - failure.trap(RequestSendFailed) + except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") - d.addErrback(http_response_errback) - d.addErrback(request_failed_errback) - return d - return f -class GroupsLocalWorkerHandler(object): +class GroupsLocalWorkerHandler: def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() @@ -72,7 +63,7 @@ class GroupsLocalWorkerHandler(object): self.clock = hs.get_clock() self.keyring = hs.get_keyring() self.is_mine_id = hs.is_mine_id - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self.server_name = hs.hostname self.notifier = hs.get_notifier() self.attestations = hs.get_groups_attestation_signing() @@ -227,7 +218,7 @@ class GroupsLocalWorkerHandler(object): results = {} failed_results = [] - for destination, dest_user_ids in iteritems(destinations): + for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( destination, list(dest_user_ids) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 4ba0042768..0ce6ddfbe4 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -19,16 +19,11 @@ import logging import urllib.parse - -from canonicaljson import json -from signedjson.key import decode_verify_key_bytes -from signedjson.sign import verify_signed_json -from unpaddedbase64 import decode_base64 +from typing import Awaitable, Callable, Dict, List, Optional, Tuple from twisted.internet.error import TimeoutError from synapse.api.errors import ( - AuthError, CodeMessageException, Codes, HttpResponseException, @@ -36,6 +31,8 @@ from synapse.api.errors import ( ) from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.client import SimpleHttpClient +from synapse.types import JsonDict, Requester +from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -59,23 +56,23 @@ class IdentityHandler(BaseHandler): self.federation_http_client = hs.get_http_client() self.hs = hs - async def threepid_from_creds(self, id_server, creds): + async def threepid_from_creds( + self, id_server: str, creds: Dict[str, str] + ) -> Optional[JsonDict]: """ Retrieve and validate a threepid identifier from a "credentials" dictionary against a given identity server Args: - id_server (str): The identity server to validate 3PIDs against. Must be a + id_server: The identity server to validate 3PIDs against. Must be a complete URL including the protocol (http(s)://) - - creds (dict[str, str]): Dictionary containing the following keys: + creds: Dictionary containing the following keys: * client_secret|clientSecret: A unique secret str provided by the client * sid: The ID of the validation session Returns: - Deferred[dict[str,str|int]|None]: A dictionary consisting of response params to - the /getValidated3pid endpoint of the Identity Service API, or None if the - threepid was not found + A dictionary consisting of response params to the /getValidated3pid + endpoint of the Identity Service API, or None if the threepid was not found """ client_secret = creds.get("client_secret") or creds.get("clientSecret") if not client_secret: @@ -119,26 +116,27 @@ class IdentityHandler(BaseHandler): return None async def bind_threepid( - self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True - ): + self, + client_secret: str, + sid: str, + mxid: str, + id_server: str, + id_access_token: Optional[str] = None, + use_v2: bool = True, + ) -> JsonDict: """Bind a 3PID to an identity server Args: - client_secret (str): A unique secret provided by the client - - sid (str): The ID of the validation session - - mxid (str): The MXID to bind the 3PID to - - id_server (str): The domain of the identity server to query - - id_access_token (str): The access token to authenticate to the identity + client_secret: A unique secret provided by the client + sid: The ID of the validation session + mxid: The MXID to bind the 3PID to + id_server: The domain of the identity server to query + id_access_token: The access token to authenticate to the identity server with, if necessary. Required if use_v2 is true - - use_v2 (bool): Whether to use v2 Identity Service API endpoints. Defaults to True + use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True Returns: - Deferred[dict]: The response from the identity server + The response from the identity server """ logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server) @@ -151,7 +149,7 @@ class IdentityHandler(BaseHandler): bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} if use_v2: bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) - headers["Authorization"] = create_id_access_token_header(id_access_token) + headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore else: bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,) @@ -178,7 +176,7 @@ class IdentityHandler(BaseHandler): except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: - data = json.loads(e.msg) # XXX WAT? + data = json_decoder.decode(e.msg) # XXX WAT? return data logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) @@ -187,20 +185,20 @@ class IdentityHandler(BaseHandler): ) return res - async def try_unbind_threepid(self, mxid, threepid): + async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool: """Attempt to remove a 3PID from an identity server, or if one is not provided, all identity servers we're aware the binding is present on Args: - mxid (str): Matrix user ID of binding to be removed - threepid (dict): Dict with medium & address of binding to be + mxid: Matrix user ID of binding to be removed + threepid: Dict with medium & address of binding to be removed, and an optional id_server. Raises: SynapseError: If we failed to contact the identity server Returns: - Deferred[bool]: True on success, otherwise False if the identity + True on success, otherwise False if the identity server doesn't support unbinding (or no identity server found to contact). """ @@ -223,19 +221,21 @@ class IdentityHandler(BaseHandler): return changed - async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): + async def try_unbind_threepid_with_id_server( + self, mxid: str, threepid: dict, id_server: str + ) -> bool: """Removes a binding from an identity server Args: - mxid (str): Matrix user ID of binding to be removed - threepid (dict): Dict with medium & address of binding to be removed - id_server (str): Identity server to unbind from + mxid: Matrix user ID of binding to be removed + threepid: Dict with medium & address of binding to be removed + id_server: Identity server to unbind from Raises: SynapseError: If we failed to contact the identity server Returns: - Deferred[bool]: True on success, otherwise False if the identity + True on success, otherwise False if the identity server doesn't support unbinding """ url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) @@ -251,10 +251,10 @@ class IdentityHandler(BaseHandler): # 'browser-like' HTTPS. auth_headers = self.federation_http_client.build_auth_headers( destination=None, - method="POST", + method=b"POST", url_bytes=url_bytes, content=content, - destination_is=id_server, + destination_is=id_server.encode("ascii"), ) headers = {b"Authorization": auth_headers} @@ -287,23 +287,23 @@ class IdentityHandler(BaseHandler): async def send_threepid_validation( self, - email_address, - client_secret, - send_attempt, - send_email_func, - next_link=None, - ): + email_address: str, + client_secret: str, + send_attempt: int, + send_email_func: Callable[[str, str, str, str], Awaitable], + next_link: Optional[str] = None, + ) -> str: """Send a threepid validation email for password reset or registration purposes Args: - email_address (str): The user's email address - client_secret (str): The provided client secret - send_attempt (int): Which send attempt this is - send_email_func (func): A function that takes an email address, token, - client_secret and session_id, sends an email - and returns a Deferred. - next_link (str|None): The URL to redirect the user to after validation + email_address: The user's email address + client_secret: The provided client secret + send_attempt: Which send attempt this is + send_email_func: A function that takes an email address, token, + client_secret and session_id, sends an email + and returns an Awaitable. + next_link: The URL to redirect the user to after validation Returns: The new session_id upon success @@ -372,17 +372,22 @@ class IdentityHandler(BaseHandler): return session_id async def requestEmailToken( - self, id_server, email, client_secret, send_attempt, next_link=None - ): + self, + id_server: str, + email: str, + client_secret: str, + send_attempt: int, + next_link: Optional[str] = None, + ) -> JsonDict: """ Request an external server send an email on our behalf for the purposes of threepid validation. Args: - id_server (str): The identity server to proxy to - email (str): The email to send the message to - client_secret (str): The unique client_secret sends by the user - send_attempt (int): Which attempt this is + id_server: The identity server to proxy to + email: The email to send the message to + client_secret: The unique client_secret sends by the user + send_attempt: Which attempt this is next_link: A link to redirect the user to once they submit the token Returns: @@ -419,22 +424,22 @@ class IdentityHandler(BaseHandler): async def requestMsisdnToken( self, - id_server, - country, - phone_number, - client_secret, - send_attempt, - next_link=None, - ): + id_server: str, + country: str, + phone_number: str, + client_secret: str, + send_attempt: int, + next_link: Optional[str] = None, + ) -> JsonDict: """ Request an external server send an SMS message on our behalf for the purposes of threepid validation. Args: - id_server (str): The identity server to proxy to - country (str): The country code of the phone number - phone_number (str): The number to send the message to - client_secret (str): The unique client_secret sends by the user - send_attempt (int): Which attempt this is + id_server: The identity server to proxy to + country: The country code of the phone number + phone_number: The number to send the message to + client_secret: The unique client_secret sends by the user + send_attempt: Which attempt this is next_link: A link to redirect the user to once they submit the token Returns: @@ -480,17 +485,18 @@ class IdentityHandler(BaseHandler): ) return data - async def validate_threepid_session(self, client_secret, sid): + async def validate_threepid_session( + self, client_secret: str, sid: str + ) -> Optional[JsonDict]: """Validates a threepid session with only the client secret and session ID Tries validating against any configured account_threepid_delegates as well as locally. Args: - client_secret (str): A secret provided by the client - - sid (str): The ID of the session + client_secret: A secret provided by the client + sid: The ID of the session Returns: - Dict[str, str|int] if validation was successful, otherwise None + The json response if validation was successful, otherwise None """ # XXX: We shouldn't need to keep wrapping and unwrapping this value threepid_creds = {"client_secret": client_secret, "sid": sid} @@ -523,23 +529,22 @@ class IdentityHandler(BaseHandler): return validation_session - async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token): + async def proxy_msisdn_submit_token( + self, id_server: str, client_secret: str, sid: str, token: str + ) -> JsonDict: """Proxy a POST submitToken request to an identity server for verification purposes Args: - id_server (str): The identity server URL to contact - - client_secret (str): Secret provided by the client - - sid (str): The ID of the session - - token (str): The verification token + id_server: The identity server URL to contact + client_secret: Secret provided by the client + sid: The ID of the session + token: The verification token Raises: SynapseError: If we failed to contact the identity server Returns: - Deferred[dict]: The response dict from the identity server + The response dict from the identity server """ body = {"client_secret": client_secret, "sid": sid, "token": token} @@ -554,19 +559,25 @@ class IdentityHandler(BaseHandler): logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) raise SynapseError(400, "Error contacting the identity server") - async def lookup_3pid(self, id_server, medium, address, id_access_token=None): + async def lookup_3pid( + self, + id_server: str, + medium: str, + address: str, + id_access_token: Optional[str] = None, + ) -> Optional[str]: """Looks up a 3pid in the passed identity server. Args: - id_server (str): The server name (including port, if required) + id_server: The server name (including port, if required) of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). - id_access_token (str|None): The access token to authenticate to the identity + medium: The type of the third party identifier (e.g. "email"). + address: The third party identifier (e.g. "foo@example.com"). + id_access_token: The access token to authenticate to the identity server with Returns: - str|None: the matrix ID of the 3pid, or None if it is not recognized. + the matrix ID of the 3pid, or None if it is not recognized. """ if id_access_token is not None: try: @@ -591,17 +602,19 @@ class IdentityHandler(BaseHandler): return await self._lookup_3pid_v1(id_server, medium, address) - async def _lookup_3pid_v1(self, id_server, medium, address): + async def _lookup_3pid_v1( + self, id_server: str, medium: str, address: str + ) -> Optional[str]: """Looks up a 3pid in the passed identity server using v1 lookup. Args: - id_server (str): The server name (including port, if required) + id_server: The server name (including port, if required) of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). + medium: The type of the third party identifier (e.g. "email"). + address: The third party identifier (e.g. "foo@example.com"). Returns: - str: the matrix ID of the 3pid, or None if it is not recognized. + the matrix ID of the 3pid, or None if it is not recognized. """ try: data = await self.blacklisting_http_client.get_json( @@ -610,9 +623,9 @@ class IdentityHandler(BaseHandler): ) if "mxid" in data: - if "signatures" not in data: - raise AuthError(401, "No signatures on 3pid binding") - await self._verify_any_signature(data, id_server) + # note: we used to verify the identity server's signature here, but no longer + # require or validate it. See the following for context: + # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950 return data["mxid"] except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") @@ -621,18 +634,20 @@ class IdentityHandler(BaseHandler): return None - async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address): + async def _lookup_3pid_v2( + self, id_server: str, id_access_token: str, medium: str, address: str + ) -> Optional[str]: """Looks up a 3pid in the passed identity server using v2 lookup. Args: - id_server (str): The server name (including port, if required) + id_server: The server name (including port, if required) of the identity server to use. - id_access_token (str): The access token to authenticate to the identity server with - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). + id_access_token: The access token to authenticate to the identity server with + medium: The type of the third party identifier (e.g. "email"). + address: The third party identifier (e.g. "foo@example.com"). Returns: - Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised. + the matrix ID of the 3pid, or None if it is not recognised. """ # Check what hashing details are supported by this identity server try: @@ -731,75 +746,50 @@ class IdentityHandler(BaseHandler): mxid = lookup_results["mappings"].get(lookup_value) return mxid - async def _verify_any_signature(self, data, server_hostname): - if server_hostname not in data["signatures"]: - raise AuthError(401, "No signature from server %s" % (server_hostname,)) - for key_name, signature in data["signatures"][server_hostname].items(): - try: - key_data = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" - % (id_server_scheme, server_hostname, key_name) - ) - except TimeoutError: - raise SynapseError(500, "Timed out contacting identity server") - if "public_key" not in key_data: - raise AuthError( - 401, "No public key named %s from %s" % (key_name, server_hostname) - ) - verify_signed_json( - data, - server_hostname, - decode_verify_key_bytes( - key_name, decode_base64(key_data["public_key"]) - ), - ) - return - async def ask_id_server_for_third_party_invite( self, - requester, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url, - id_access_token=None, - ): + requester: Requester, + id_server: str, + medium: str, + address: str, + room_id: str, + inviter_user_id: str, + room_alias: str, + room_avatar_url: str, + room_join_rules: str, + room_name: str, + inviter_display_name: str, + inviter_avatar_url: str, + id_access_token: Optional[str] = None, + ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]: """ Asks an identity server for a third party invite. Args: - requester (Requester) - id_server (str): hostname + optional port for the identity server. - medium (str): The literal string "email". - address (str): The third party address being invited. - room_id (str): The ID of the room to which the user is invited. - inviter_user_id (str): The user ID of the inviter. - room_alias (str): An alias for the room, for cosmetic notifications. - room_avatar_url (str): The URL of the room's avatar, for cosmetic + requester + id_server: hostname + optional port for the identity server. + medium: The literal string "email". + address: The third party address being invited. + room_id: The ID of the room to which the user is invited. + inviter_user_id: The user ID of the inviter. + room_alias: An alias for the room, for cosmetic notifications. + room_avatar_url: The URL of the room's avatar, for cosmetic notifications. - room_join_rules (str): The join rules of the email (e.g. "public"). - room_name (str): The m.room.name of the room. - inviter_display_name (str): The current display name of the + room_join_rules: The join rules of the email (e.g. "public"). + room_name: The m.room.name of the room. + inviter_display_name: The current display name of the inviter. - inviter_avatar_url (str): The URL of the inviter's avatar. + inviter_avatar_url: The URL of the inviter's avatar. id_access_token (str|None): The access token to authenticate to the identity server with Returns: - A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. + A tuple containing: + token: The token which must be signed to prove authenticity. public_keys ([{"public_key": str, "key_validity_url": str}]): public_key is a base64-encoded ed25519 public key. fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. + display_name: A user-friendly name to represent the invited user. """ invite_config = { "medium": medium, @@ -896,15 +886,15 @@ class IdentityHandler(BaseHandler): return token, public_keys, fallback_public_key, display_name -def create_id_access_token_header(id_access_token): +def create_id_access_token_header(id_access_token: str) -> List[str]: """Create an Authorization header for passing to SimpleHttpClient as the header value of an HTTP request. Args: - id_access_token (str): An identity server access token. + id_access_token: An identity server access token. Returns: - list[str]: The ascii-encoded bearer token encased in a list. + The ascii-encoded bearer token encased in a list. """ # Prefix with Bearer bearer_token = "Bearer %s" % id_access_token diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index f88bad5f25..d5ddc583ad 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from twisted.internet import defer @@ -22,8 +23,9 @@ from synapse.api.errors import SynapseError from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig -from synapse.types import StreamToken, UserID +from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -31,11 +33,15 @@ from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) class InitialSyncHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(InitialSyncHandler, self).__init__(hs) self.hs = hs self.state = hs.get_state_handler() @@ -48,27 +54,25 @@ class InitialSyncHandler(BaseHandler): def snapshot_all_rooms( self, - user_id=None, - pagin_config=None, - as_client_event=True, - include_archived=False, - ): + user_id: str, + pagin_config: PaginationConfig, + as_client_event: bool = True, + include_archived: bool = False, + ) -> JsonDict: """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is joined, depending on the pagination config. Args: - user_id (str): The ID of the user making the request. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config used to determine how many messages *PER ROOM* to return. - as_client_event (bool): True to get events in client-server format. - include_archived (bool): True to get rooms that the user has left + user_id: The ID of the user making the request. + pagin_config: The pagination config used to determine how many + messages *PER ROOM* to return. + as_client_event: True to get events in client-server format. + include_archived: True to get rooms that the user has left Returns: - A list of dicts with "room_id" and "membership" keys for all rooms - the user is currently invited or joined in on. Rooms where the user - is joined on, may return a "messages" key with messages, depending - on the specified PaginationConfig. + A JsonDict with the same format as the response to `/intialSync` + API """ key = ( user_id, @@ -91,11 +95,11 @@ class InitialSyncHandler(BaseHandler): async def _snapshot_all_rooms( self, - user_id=None, - pagin_config=None, - as_client_event=True, - include_archived=False, - ): + user_id: str, + pagin_config: PaginationConfig, + as_client_event: bool = True, + include_archived: bool = False, + ) -> JsonDict: memberships = [Membership.INVITE, Membership.JOIN] if include_archived: @@ -109,7 +113,7 @@ class InitialSyncHandler(BaseHandler): rooms_ret = [] - now_token = await self.hs.get_event_sources().get_current_token() + now_token = self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] pagination_config = PaginationConfig(from_token=now_token) @@ -134,7 +138,7 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - async def handle_room(event): + async def handle_room(event: RoomsForUser): d = { "room_id": event.room_id, "membership": event.membership, @@ -251,17 +255,18 @@ class InitialSyncHandler(BaseHandler): return ret - async def room_initial_sync(self, requester, room_id, pagin_config=None): + async def room_initial_sync( + self, requester: Requester, room_id: str, pagin_config: PaginationConfig + ) -> JsonDict: """Capture the a snapshot of a room. If user is currently a member of the room this will be what is currently in the room. If the user left the room this will be what was in the room when they left. Args: - requester(Requester): The user to get a snapshot for. - room_id(str): The room to get a snapshot of. - pagin_config(synapse.streams.config.PaginationConfig): - The pagination config used to determine how many messages to - return. + requester: The user to get a snapshot for. + room_id: The room to get a snapshot of. + pagin_config: The pagination config used to determine how many + messages to return. Raises: AuthError if the user wasn't in the room. Returns: @@ -305,8 +310,14 @@ class InitialSyncHandler(BaseHandler): return result async def _room_initial_sync_parted( - self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking - ): + self, + user_id: str, + room_id: str, + pagin_config: PaginationConfig, + membership: Membership, + member_event_id: str, + is_peeking: bool, + ) -> JsonDict: room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -350,8 +361,13 @@ class InitialSyncHandler(BaseHandler): } async def _room_initial_sync_joined( - self, user_id, room_id, pagin_config, membership, is_peeking - ): + self, + user_id: str, + room_id: str, + pagin_config: PaginationConfig, + membership: Membership, + is_peeking: bool, + ) -> JsonDict: current_state = await self.state.get_current_state(room_id=room_id) # TODO: These concurrently @@ -360,7 +376,7 @@ class InitialSyncHandler(BaseHandler): current_state.values(), time_now ) - now_token = await self.hs.get_event_sources().get_current_token() + now_token = self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None if limit is None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 649ca1f08a..8a7b4916cd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,14 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional, Tuple +import random +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from six import iteritems, itervalues, string_types +from canonicaljson import encode_canonical_json -from canonicaljson import encode_canonical_json, json - -from twisted.internet import defer -from twisted.internet.defer import succeed from twisted.internet.interfaces import IDelayedCall from synapse import event_auth @@ -38,18 +35,22 @@ from synapse.api.errors import ( Codes, ConsentNotGivenError, NotFoundError, + ShadowBanError, SynapseError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events import EventBase +from synapse.events.builder import EventBuilder +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import Collection, RoomAlias, UserID, create_requester +from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester +from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func @@ -57,10 +58,13 @@ from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) -class MessageHandler(object): +class MessageHandler: """Contains some read only APIs to get state about a room """ @@ -83,46 +87,57 @@ class MessageHandler(object): "_schedule_next_expiry", self._schedule_next_expiry ) - @defer.inlineCallbacks - def get_room_data( - self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False - ): + async def get_room_data( + self, user_id: str, room_id: str, event_type: str, state_key: str, + ) -> dict: """ Get data from a room. Args: - event : The room path event + user_id + room_id + event_type + state_key Returns: The path data content. Raises: - SynapseError if something went wrong. + SynapseError or AuthError if the user is not in the room """ ( membership, membership_event_id, - ) = yield self.auth.check_user_in_room_or_world_readable( + ) = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: - data = yield self.state.get_current_state(room_id, event_type, state_key) + data = await self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) + else: + # check_user_in_room_or_world_readable, if it doesn't raise an AuthError, should + # only ever return a Membership.JOIN/LEAVE object + # + # Safeguard in case it returned something else + logger.error( + "Attempted to retrieve data from a room for a user that has never been in it. " + "This should not have happened." + ) + raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN) return data - @defer.inlineCallbacks - def get_state_events( + async def get_state_events( self, - user_id, - room_id, - state_filter=StateFilter.all(), - at_token=None, - is_guest=False, - ): + user_id: str, + room_id: str, + state_filter: StateFilter = StateFilter.all(), + at_token: Optional[StreamToken] = None, + is_guest: bool = False, + ) -> List[dict]: """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has left the room return the state events from when they left. If an explicit @@ -130,15 +145,14 @@ class MessageHandler(object): visible. Args: - user_id(str): The user requesting state events. - room_id(str): The room ID to get all state events from. - state_filter (StateFilter): The state filter used to fetch state - from the database. - at_token(StreamToken|None): the stream token of the at which we are requesting + user_id: The user requesting state events. + room_id: The room ID to get all state events from. + state_filter: The state filter used to fetch state from the database. + at_token: the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current state based on the current_state_events table. - is_guest(bool): whether this user is a guest + is_guest: whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] Raises: @@ -152,20 +166,20 @@ class MessageHandler(object): # get_recent_events_for_room operates by topo ordering. This therefore # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) - last_events, _ = yield self.store.get_recent_events_for_room( + last_events, _ = await self.store.get_recent_events_for_room( room_id, end_token=at_token.room_key, limit=1 ) if not last_events: raise NotFoundError("Can't find event for token %s" % (at_token,)) - visible_events = yield filter_events_for_client( + visible_events = await filter_events_for_client( self.storage, user_id, last_events, filter_send_to_client=False ) event = last_events[0] if visible_events: - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -179,23 +193,23 @@ class MessageHandler(object): ( membership, membership_event_id, - ) = yield self.auth.check_user_in_room_or_world_readable( + ) = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: - state_ids = yield self.store.get_filtered_current_state_ids( + state_ids = await self.store.get_filtered_current_state_ids( room_id, state_filter=state_filter ) - room_state = yield self.store.get_events(state_ids.values()) + room_state = await self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events( + events = await self._event_serializer.serialize_events( room_state.values(), now, # We don't bother bundling aggregations in when asked for state @@ -204,15 +218,14 @@ class MessageHandler(object): ) return events - @defer.inlineCallbacks - def get_joined_members(self, requester, room_id): + async def get_joined_members(self, requester: Requester, room_id: str) -> dict: """Get all the joined members in the room and their profile information. If the user has left the room return the state events from when they left. Args: - requester(Requester): The user requesting state events. - room_id(str): The room ID to get all state events from. + requester: The user requesting state events. + room_id: The room ID to get all state events from. Returns: A dict of user_id to profile info """ @@ -220,7 +233,7 @@ class MessageHandler(object): if not requester.app_service: # We check AS auth after fetching the room membership, as it # requires us to pull out all joined members anyway. - membership, _ = yield self.auth.check_user_in_room_or_world_readable( + membership, _ = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership != Membership.JOIN: @@ -228,7 +241,7 @@ class MessageHandler(object): "Getting joined members after leaving is not implemented" ) - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there @@ -246,10 +259,10 @@ class MessageHandler(object): "avatar_url": profile.avatar_url, "display_name": profile.display_name, } - for user_id, profile in iteritems(users_with_profile) + for user_id, profile in users_with_profile.items() } - def maybe_schedule_expiry(self, event): + def maybe_schedule_expiry(self, event: EventBase): """Schedule the expiry of an event if there's not already one scheduled, or if the one running is for an event that will expire after the provided timestamp. @@ -258,7 +271,7 @@ class MessageHandler(object): the master process, and therefore needs to be run on there. Args: - event (EventBase): The event to schedule the expiry of. + event: The event to schedule the expiry of. """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) @@ -269,8 +282,7 @@ class MessageHandler(object): # a task scheduled for a timestamp that's sooner than the provided one. self._schedule_expiry_for_event(event.event_id, expiry_ts) - @defer.inlineCallbacks - def _schedule_next_expiry(self): + async def _schedule_next_expiry(self): """Retrieve the ID and the expiry timestamp of the next event to be expired, and schedule an expiry task for it. @@ -278,18 +290,18 @@ class MessageHandler(object): future call to save_expiry_ts can schedule a new expiry task. """ # Try to get the expiry timestamp of the next event to expire. - res = yield self.store.get_next_event_to_expire() + res = await self.store.get_next_event_to_expire() if res: event_id, expiry_ts = res self._schedule_expiry_for_event(event_id, expiry_ts) - def _schedule_expiry_for_event(self, event_id, expiry_ts): + def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int): """Schedule an expiry task for the provided event if there's not already one scheduled at a timestamp that's sooner than the provided one. Args: - event_id (str): The ID of the event to expire. - expiry_ts (int): The timestamp at which to expire the event. + event_id: The ID of the event to expire. + expiry_ts: The timestamp at which to expire the event. """ if self._scheduled_expiry: # If the provided timestamp refers to a time before the scheduled time of the @@ -319,8 +331,7 @@ class MessageHandler(object): event_id, ) - @defer.inlineCallbacks - def _expire_event(self, event_id): + async def _expire_event(self, event_id: str): """Retrieve and expire an event that needs to be expired from the database. If the event doesn't exist in the database, log it and delete the expiry date @@ -335,12 +346,12 @@ class MessageHandler(object): try: # Expire the event if we know about it. This function also deletes the expiry # date from the database in the same database transaction. - yield self.store.expire_event(event_id) + await self.store.expire_event(event_id) except Exception as e: logger.error("Could not expire event %s: %r", event_id, e) # Schedule the expiry of the next event to expire. - yield self._schedule_next_expiry() + await self._schedule_next_expiry() # The duration (in ms) after which rooms should be removed @@ -350,8 +361,8 @@ class MessageHandler(object): _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000 -class EventCreationHandler(object): - def __init__(self, hs): +class EventCreationHandler: + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -396,7 +407,7 @@ class EventCreationHandler(object): # # map from room id to time-of-last-attempt. # - self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int] + self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] # we need to construct a ConsentURIBuilder here, as it checks that the necessary # config options, but *only* if we have a configuration for which we are @@ -422,16 +433,15 @@ class EventCreationHandler(object): self._dummy_events_threshold = hs.config.dummy_events_threshold - @defer.inlineCallbacks - def create_event( + async def create_event( self, - requester, - event_dict, - token_id=None, - txn_id=None, - prev_event_ids: Optional[Collection[str]] = None, - require_consent=True, - ): + requester: Requester, + event_dict: dict, + token_id: Optional[str] = None, + txn_id: Optional[str] = None, + prev_event_ids: Optional[List[str]] = None, + require_consent: bool = True, + ) -> Tuple[EventBase, EventContext]: """ Given a dict from a client, create a new event. @@ -442,31 +452,29 @@ class EventCreationHandler(object): Args: requester - event_dict (dict): An entire event - token_id (str) - txn_id (str) - + event_dict: An entire event + token_id + txn_id prev_event_ids: the forward extremities to use as the prev_events for the new event. If None, they will be requested from the database. - - require_consent (bool): Whether to check if the requester has - consented to privacy policy. + require_consent: Whether to check if the requester has + consented to the privacy policy. Raises: ResourceLimitError if server is blocked to some resource being exceeded Returns: - Tuple of created event (FrozenEvent), Context + Tuple of created event, Context """ - yield self.auth.check_auth_blocking(requester.user.to_string()) + await self.auth.check_auth_blocking(requester.user.to_string()) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version = event_dict["content"]["room_version"] else: try: - room_version = yield self.store.get_room_version_id( + room_version = await self.store.get_room_version_id( event_dict["room_id"] ) except NotFoundError: @@ -487,11 +495,11 @@ class EventCreationHandler(object): try: if "displayname" not in content: - displayname = yield profile.get_displayname(target) + displayname = await profile.get_displayname(target) if displayname is not None: content["displayname"] = displayname if "avatar_url" not in content: - avatar_url = yield profile.get_avatar_url(target) + avatar_url = await profile.get_avatar_url(target) if avatar_url is not None: content["avatar_url"] = avatar_url except Exception as e: @@ -499,9 +507,9 @@ class EventCreationHandler(object): "Failed to get profile information for %r: %s", target, e ) - is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) + is_exempt = await self._is_exempt_from_privacy_policy(builder, requester) if require_consent and not is_exempt: - yield self.assert_accepted_privacy_policy(requester) + await self.assert_accepted_privacy_policy(requester) if token_id is not None: builder.internal_metadata.token_id = token_id @@ -509,7 +517,7 @@ class EventCreationHandler(object): if txn_id is not None: builder.internal_metadata.txn_id = txn_id - event, context = yield self.create_new_client_event( + event, context = await self.create_new_client_event( builder=builder, requester=requester, prev_event_ids=prev_event_ids, ) @@ -525,10 +533,10 @@ class EventCreationHandler(object): # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( - yield self.store.get_event(prev_event_id, allow_none=True) + await self.store.get_event(prev_event_id, allow_none=True) if prev_event_id else None ) @@ -551,37 +559,36 @@ class EventCreationHandler(object): return (event, context) - def _is_exempt_from_privacy_policy(self, builder, requester): + async def _is_exempt_from_privacy_policy( + self, builder: EventBuilder, requester: Requester + ) -> bool: """"Determine if an event to be sent is exempt from having to consent to the privacy policy Args: - builder (synapse.events.builder.EventBuilder): event being created - requester (Requster): user requesting this event + builder: event being created + requester: user requesting this event Returns: - Deferred[bool]: true if the event can be sent without the user - consenting + true if the event can be sent without the user consenting """ # the only thing the user can do is join the server notices room. if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: - return self._is_server_notices_room(builder.room_id) + return await self._is_server_notices_room(builder.room_id) elif membership == Membership.LEAVE: # the user is always allowed to leave (but not kick people) return builder.state_key == requester.user.to_string() - return succeed(False) + return False - @defer.inlineCallbacks - def _is_server_notices_room(self, room_id): + async def _is_server_notices_room(self, room_id: str) -> bool: if self.config.server_notices_mxid is None: return False - user_ids = yield self.store.get_users_in_room(room_id) + user_ids = await self.store.get_users_in_room(room_id) return self.config.server_notices_mxid in user_ids - @defer.inlineCallbacks - def assert_accepted_privacy_policy(self, requester): + async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy Called when the given user is about to do something that requires @@ -590,12 +597,10 @@ class EventCreationHandler(object): raised. Args: - requester (synapse.types.Requester): - The user making the request + requester: The user making the request Returns: - Deferred[None]: returns normally if the user has consented or is - exempt + Returns normally if the user has consented or is exempt Raises: ConsentNotGivenError: if the user has not given consent yet @@ -616,7 +621,7 @@ class EventCreationHandler(object): ): return - u = yield self.store.get_user_by_id(user_id) + u = await self.store.get_user_by_id(user_id) assert u is not None if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): # support and bot users are not required to consent @@ -634,74 +639,115 @@ class EventCreationHandler(object): raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) async def send_nonmember_event( - self, requester, event, context, ratelimit=True + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + ignore_shadow_ban: bool = False, ) -> int: """ Persists and notifies local clients and federation of an event. Args: - event (FrozenEvent) the event to send. - context (Context) the context of the event. - ratelimit (bool): Whether to rate limit this send. - is_guest (bool): Whether the sender is a guest. + requester: The requester sending the event. + event: The event to send. + context: The context of the event. + ratelimit: Whether to rate limit this send. + ignore_shadow_ban: True if shadow-banned users should be allowed to + send this event. Return: The stream_id of the persisted event. + + Raises: + ShadowBanError if the requester has been shadow-banned. """ if event.type == EventTypes.Member: raise SynapseError( 500, "Tried to send member event through non-member codepath" ) + if not ignore_shadow_ban and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = await self.deduplicate_state_event(event, context) - if prev_state is not None: + prev_event = await self.deduplicate_state_event(event, context) + if prev_event is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", event.event_id, - prev_state.event_id, + prev_event.event_id, ) - return prev_state + return await self.store.get_stream_id_for_event(prev_event.event_id) return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) - @defer.inlineCallbacks - def deduplicate_state_event(self, event, context): + async def deduplicate_state_event( + self, event: EventBase, context: EventContext + ) -> Optional[EventBase]: """ Checks whether event is in the latest resolved state in context. - If so, returns the version of the event in context. - Otherwise, returns None. + Args: + event: The event to check for duplication. + context: The event context. + + Returns: + The previous verion of the event is returned, if it is found in the + event context. Otherwise, None is returned. """ - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: - return - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + return None + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if not prev_event: - return + return None if prev_event and event.user_id == prev_event.user_id: prev_content = encode_canonical_json(prev_event.content) next_content = encode_canonical_json(event.content) if prev_content == next_content: return prev_event - return + return None async def create_and_send_nonmember_event( - self, requester, event_dict, ratelimit=True, txn_id=None + self, + requester: Requester, + event_dict: dict, + ratelimit: bool = True, + txn_id: Optional[str] = None, + ignore_shadow_ban: bool = False, ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. See self.create_event and self.send_nonmember_event. + + Args: + requester: The requester sending the event. + event_dict: An entire event. + ratelimit: Whether to rate limit this send. + txn_id: The transaction ID. + ignore_shadow_ban: True if shadow-banned users should be allowed to + send this event. + + Raises: + ShadowBanError if the requester has been shadow-banned. """ + if not ignore_shadow_ban and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() # We limit the number of concurrent event sends in a room so that we # don't fork the DAG too much. If we don't limit then we can end up in @@ -715,27 +761,31 @@ class EventCreationHandler(object): spam_error = self.spam_checker.check_event_for_spam(event) if spam_error: - if not isinstance(spam_error, string_types): + if not isinstance(spam_error, str): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) stream_id = await self.send_nonmember_event( - requester, event, context, ratelimit=ratelimit + requester, + event, + context, + ratelimit=ratelimit, + ignore_shadow_ban=ignore_shadow_ban, ) return event, stream_id @measure_func("create_new_client_event") - @defer.inlineCallbacks - def create_new_client_event( - self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None - ): + async def create_new_client_event( + self, + builder: EventBuilder, + requester: Optional[Requester] = None, + prev_event_ids: Optional[List[str]] = None, + ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client Args: - builder (EventBuilder): - - requester (synapse.types.Requester|None): - + builder: + requester: prev_event_ids: the forward extremities to use as the prev_events for the new event. @@ -743,7 +793,7 @@ class EventCreationHandler(object): If None, they will be requested from the database. Returns: - Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] + Tuple of created event, context """ if prev_event_ids is not None: @@ -752,10 +802,19 @@ class EventCreationHandler(object): % (len(prev_event_ids),) ) else: - prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id) + prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) - event = yield builder.build(prev_event_ids=prev_event_ids) - context = yield self.state.compute_event_context(event) + # we now ought to have some prev_events (unless it's a create event). + # + # do a quick sanity check here, rather than waiting until we've created the + # event and then try to auth it (which fails with a somewhat confusing "No + # create event in auth events") + assert ( + builder.type == EventTypes.Create or len(prev_event_ids) > 0 + ), "Attempting to create an event with no prev_events" + + event = await builder.build(prev_event_ids=prev_event_ids) + context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service @@ -769,7 +828,7 @@ class EventCreationHandler(object): relates_to = relation["event_id"] aggregation_key = relation["key"] - already_exists = yield self.store.has_user_annotated_event( + already_exists = await self.store.has_user_annotated_event( relates_to, event.type, aggregation_key, event.sender ) if already_exists: @@ -781,7 +840,12 @@ class EventCreationHandler(object): @measure_func("handle_new_client_event") async def handle_new_client_event( - self, requester, event, context, ratelimit=True, extra_users=[] + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + extra_users: List[UserID] = [], ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -790,11 +854,11 @@ class EventCreationHandler(object): processing. Args: - requester (Requester) - event (FrozenEvent) - context (EventContext) - ratelimit (bool) - extra_users (list(UserID)): Any extra users to notify about event + requester + event + context + ratelimit + extra_users: Any extra users to notify about event Return: The stream_id of the persisted event. @@ -816,25 +880,28 @@ class EventCreationHandler(object): 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - try: - await self.auth.check_from_context(room_version, event, context) - except AuthError as err: - logger.warning("Denying new event %r because %s", event, err) - raise err + if event.internal_metadata.is_out_of_band_membership(): + # the only sort of out-of-band-membership events we expect to see here + # are invite rejections we have generated ourselves. + assert event.type == EventTypes.Member + assert event.content["membership"] == Membership.LEAVE + else: + try: + await self.auth.check_from_context(room_version, event, context) + except AuthError as err: + logger.warning("Denying new event %r because %s", event, err) + raise err # Ensure that we can round trip before trying to persist in db try: dump = frozendict_json_encoder.encode(event.content) - json.loads(dump) + json_decoder.decode(dump) except Exception: logger.exception("Failed to encode content: %r", event.content) raise await self.action_generator.handle_push_actions_for_event(event, context) - # reraise does not allow inlineCallbacks to preserve the stacktrace, so we - # hack around with a try/finally instead. - success = False try: # If we're a worker we need to hit out to the master. if not self._is_event_writer: @@ -850,27 +917,22 @@ class EventCreationHandler(object): ) stream_id = result["stream_id"] event.internal_metadata.stream_ordering = stream_id - success = True return stream_id stream_id = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - success = True return stream_id - finally: - if not success: - # Ensure that we actually remove the entries in the push actions - # staging area, if we calculated them. - run_in_background( - self.store.remove_push_actions_from_staging, event.event_id - ) + except Exception: + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + await self.store.remove_push_actions_from_staging(event.event_id) + raise - @defer.inlineCallbacks - def _validate_canonical_alias( - self, directory_handler, room_alias_str, expected_room_id - ): + async def _validate_canonical_alias( + self, directory_handler, room_alias_str: str, expected_room_id: str + ) -> None: """ Ensure that the given room alias points to the expected room ID. @@ -881,7 +943,7 @@ class EventCreationHandler(object): """ room_alias = RoomAlias.from_string(room_alias_str) try: - mapping = yield directory_handler.get_association(room_alias) + mapping = await directory_handler.get_association(room_alias) except SynapseError as e: # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. if e.errcode == Codes.NOT_FOUND: @@ -900,7 +962,12 @@ class EventCreationHandler(object): ) async def persist_and_notify_client_event( - self, requester, event, context, ratelimit=True, extra_users=[] + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + extra_users: List[UserID] = [], ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -924,7 +991,7 @@ class EventCreationHandler(object): allow_none=True, ) - is_admin_redaction = ( + is_admin_redaction = bool( original_event and event.sender != original_event.sender ) @@ -938,7 +1005,7 @@ class EventCreationHandler(object): # Validate a newly added alias or newly added alt_aliases. original_alias = None - original_alt_aliases = set() + original_alt_aliases = [] # type: List[str] original_event_id = event.unsigned.get("replaces_state") if original_event_id: @@ -986,9 +1053,13 @@ class EventCreationHandler(object): current_state_ids = await context.get_current_state_ids() + # We know this event is not an outlier, so this must be + # non-None. + assert current_state_ids is not None + state_to_include_ids = [ e_id - for k, e_id in iteritems(current_state_ids) + for k, e_id in current_state_ids.items() if k[0] in self.room_invite_state_types or k == (EventTypes.Member, event.sender) ] @@ -1002,7 +1073,7 @@ class EventCreationHandler(object): "content": e.content, "sender": e.sender, } - for e in itervalues(state_to_include) + for e in state_to_include.values() ] invitee = UserID.from_string(event.state_key) @@ -1037,11 +1108,11 @@ class EventCreationHandler(object): raise SynapseError(400, "Cannot redact event from a different room") prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] @@ -1093,7 +1164,7 @@ class EventCreationHandler(object): return event_stream_id - async def _bump_active_time(self, user): + async def _bump_active_time(self, user: UserID) -> None: try: presence = self.hs.get_presence_handler() await presence.bump_presence_active_time(user) @@ -1139,8 +1210,14 @@ class EventCreationHandler(object): event.internal_metadata.proactively_send = False + # Since this is a dummy-event it is OK if it is sent by a + # shadow-banned user. await self.send_nonmember_event( - requester, event, context, ratelimit=False + requester, + event, + context, + ratelimit=False, + ignore_shadow_ban=True, ) dummy_event_sent = True break diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 9c08eb5399..1b06f3173f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import logging -from typing import Dict, Generic, List, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode import attr @@ -35,12 +34,14 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from synapse.config import ConfigError -from synapse.http.server import finish_request +from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.push.mailer import load_jinja2_templates -from synapse.server import HomeServer from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.util import json_decoder + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -91,7 +92,8 @@ class OidcHandler: """Handles requests related to the OpenID Connect login flow. """ - def __init__(self, hs: HomeServer): + def __init__(self, hs: "HomeServer"): + self.hs = hs self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._client_auth = ClientAuth( @@ -121,9 +123,7 @@ class OidcHandler: self._hostname = hs.hostname # type: str self._server_name = hs.config.server_name # type: str self._macaroon_secret_key = hs.config.macaroon_secret_key - self._error_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_error.html"] - )[0] + self._error_template = hs.config.sso_error_template # identifier for the external_ids table self._auth_provider_id = "oidc" @@ -144,15 +144,10 @@ class OidcHandler: access_denied. error_description: A human-readable description of the error. """ - html_bytes = self._error_template.render( + html = self._error_template.render( error=error, error_description=error_description - ).encode("utf-8") - - request.setResponseCode(400) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%i" % len(html_bytes)) - request.write(html_bytes) - finish_request(request) + ) + respond_with_html(request, 400, html) def _validate_metadata(self): """Verifies the provider metadata. @@ -373,7 +368,7 @@ class OidcHandler: # and check for an error field. If not, we respond with a generic # error message. try: - resp = json.loads(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) error = resp["error"] description = resp.get("error_description", error) except (ValueError, KeyError): @@ -390,7 +385,7 @@ class OidcHandler: # Since it is a not a 5xx code, body should be a valid JSON. It will # raise if not. - resp = json.loads(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) if "error" in resp: error = resp["error"] @@ -695,9 +690,17 @@ class OidcHandler: self._render_error(request, "invalid_token", str(e)) return + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user(userinfo, token) + user_id = await self._map_userinfo_to_user( + userinfo, token, user_agent, ip_address + ) except MappingException as e: logger.exception("Could not map user") self._render_error(request, "mapping_error", str(e)) @@ -834,7 +837,9 @@ class OidcHandler: now = self._clock.time_msec() return now < expiry - async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + async def _map_userinfo_to_user( + self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str + ) -> str: """Maps a UserInfo object to a mxid. UserInfo should have a claim that uniquely identifies users. This claim @@ -849,6 +854,8 @@ class OidcHandler: Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Raises: MappingException: if there was an error while mapping some properties @@ -862,6 +869,9 @@ class OidcHandler: raise MappingException( "Failed to extract subject from OIDC response: %s" % (e,) ) + # Some OIDC providers use integer IDs, but Synapse expects external IDs + # to be strings. + remote_user_id = str(remote_user_id) logger.info( "Looking for existing mapping for user %s:%s", @@ -905,7 +915,9 @@ class OidcHandler: # It's the first time this user is logging in and the mapped mxid was # not taken, register the user registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=attributes["display_name"], + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index d7442c62a7..34ed0e2921 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -14,26 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Any, Dict, Optional, Set -from six import iteritems - -from twisted.internet import defer from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.api.filtering import Filter from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter -from synapse.types import RoomStreamToken +from synapse.streams.config import PaginationConfig +from synapse.types import Requester, RoomStreamToken from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) -class PurgeStatus(object): +class PurgeStatus: """Object tracking the status of a purge request This class contains information on the progress of a purge request, for @@ -61,14 +65,14 @@ class PurgeStatus(object): return {"status": PurgeStatus.STATUS_TEXT[self.status]} -class PaginationHandler(object): +class PaginationHandler: """Handles pagination and purge history requests. These are in the same handler due to the fact we need to block clients paginating during a purge. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -78,13 +82,16 @@ class PaginationHandler(object): self._server_name = hs.hostname self.pagination_lock = ReadWriteLock() - self._purges_in_progress_by_room = set() + self._purges_in_progress_by_room = set() # type: Set[str] # map from purge id to PurgeStatus - self._purges_by_id = {} + self._purges_by_id = {} # type: Dict[str, PurgeStatus] self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min + self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max + if hs.config.retention_enabled: # Run the purge jobs described in the configuration file. for job in hs.config.retention_purge_jobs: @@ -99,8 +106,9 @@ class PaginationHandler(object): job["longest_max_lifetime"], ) - @defer.inlineCallbacks - def purge_history_for_rooms_in_range(self, min_ms, max_ms): + async def purge_history_for_rooms_in_range( + self, min_ms: Optional[int], max_ms: Optional[int] + ): """Purge outdated events from rooms within the given retention range. If a default retention policy is defined in the server's configuration and its @@ -108,14 +116,14 @@ class PaginationHandler(object): retention policy. Args: - min_ms (int|None): Duration in milliseconds that define the lower limit of + min_ms: Duration in milliseconds that define the lower limit of the range to handle (exclusive). If None, it means that the range has no lower limit. - max_ms (int|None): Duration in milliseconds that define the upper limit of + max_ms: Duration in milliseconds that define the upper limit of the range to handle (inclusive). If None, it means that the range has no upper limit. """ - # We want the storage layer to to include rooms with no retention policy in its + # We want the storage layer to include rooms with no retention policy in its # return value only if a default retention policy is defined in the server's # configuration and that policy's 'max_lifetime' is either lower (or equal) than # max_ms or higher than min_ms (or both). @@ -139,13 +147,13 @@ class PaginationHandler(object): include_null, ) - rooms = yield self.store.get_rooms_for_retention_period_in_range( + rooms = await self.store.get_rooms_for_retention_period_in_range( min_ms, max_ms, include_null ) logger.debug("[purge] Rooms to purge: %s", rooms) - for room_id, retention_policy in iteritems(rooms): + for room_id, retention_policy in rooms.items(): logger.info("[purge] Attempting to purge messages in room %s", room_id) if room_id in self._purges_in_progress_by_room: @@ -156,20 +164,39 @@ class PaginationHandler(object): ) continue - max_lifetime = retention_policy["max_lifetime"] + # If max_lifetime is None, it means that the room has no retention policy. + # Given we only retrieve such rooms when there's a default retention policy + # defined in the server's configuration, we can safely assume that's the + # case and use it for this room. + max_lifetime = ( + retention_policy["max_lifetime"] or self._retention_default_max_lifetime + ) - if max_lifetime is None: - # If max_lifetime is None, it means that include_null equals True, - # therefore we can safely assume that there is a default policy defined - # in the server's configuration. - max_lifetime = self._retention_default_max_lifetime + # Cap the effective max_lifetime to be within the range allowed in the + # config. + # We do this in two steps: + # 1. Make sure it's higher or equal to the minimum allowed value, and if + # it's not replace it with that value. This is because the server + # operator can be required to not delete information before a given + # time, e.g. to comply with freedom of information laws. + # 2. Make sure the resulting value is lower or equal to the maximum allowed + # value, and if it's not replace it with that value. This is because the + # server operator can be required to delete any data after a specific + # amount of time. + if self._retention_allowed_lifetime_min is not None: + max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime) + + if self._retention_allowed_lifetime_max is not None: + max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max) + + logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime) # Figure out what token we should start purging at. ts = self.clock.time_msec() - max_lifetime - stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) + stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) - r = yield self.store.get_room_event_before_stream_ordering( + r = await self.store.get_room_event_before_stream_ordering( room_id, stream_ordering, ) if not r: @@ -199,18 +226,19 @@ class PaginationHandler(object): "_purge_history", self._purge_history, purge_id, room_id, token, True, ) - def start_purge_history(self, room_id, token, delete_local_events=False): + def start_purge_history( + self, room_id: str, token: str, delete_local_events: bool = False + ) -> str: """Start off a history purge on a room. Args: - room_id (str): The room to purge from - - token (str): topological token to delete events before - delete_local_events (bool): True to delete local events as well as + room_id: The room to purge from + token: topological token to delete events before + delete_local_events: True to delete local events as well as remote ones Returns: - str: unique ID for this purge transaction. + unique ID for this purge transaction. """ if room_id in self._purges_in_progress_by_room: raise SynapseError( @@ -229,24 +257,21 @@ class PaginationHandler(object): ) return purge_id - @defer.inlineCallbacks - def _purge_history(self, purge_id, room_id, token, delete_local_events): + async def _purge_history( + self, purge_id: str, room_id: str, token: str, delete_local_events: bool + ) -> None: """Carry out a history purge on a room. Args: - purge_id (str): The id for this purge - room_id (str): The room to purge from - token (str): topological token to delete events before - delete_local_events (bool): True to delete local events as well as - remote ones - - Returns: - Deferred + purge_id: The id for this purge + room_id: The room to purge from + token: topological token to delete events before + delete_local_events: True to delete local events as well as remote ones """ self._purges_in_progress_by_room.add(room_id) try: - with (yield self.pagination_lock.write(room_id)): - yield self.storage.purge_events.purge_history( + with await self.pagination_lock.write(room_id): + await self.storage.purge_events.purge_history( room_id, token, delete_local_events ) logger.info("[purge] complete") @@ -266,27 +291,22 @@ class PaginationHandler(object): self.hs.get_reactor().callLater(24 * 3600, clear_purge) - def get_purge_status(self, purge_id): + def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]: """Get the current status of an active purge Args: - purge_id (str): purge_id returned by start_purge_history - - Returns: - PurgeStatus|None + purge_id: purge_id returned by start_purge_history """ return self._purges_by_id.get(purge_id) - async def purge_room(self, room_id): + async def purge_room(self, room_id: str) -> None: """Purge the given room from the database""" - with (await self.pagination_lock.write(room_id)): + with await self.pagination_lock.write(room_id): # check we know about the room await self.store.get_room_version_id(room_id) # first check that we have no users in this room - joined = await defer.maybeDeferred( - self.store.is_host_joined, room_id, self._server_name - ) + joined = await self.store.is_host_joined(room_id, self._server_name) if joined: raise SynapseError(400, "Users are still joined to this room") @@ -295,23 +315,22 @@ class PaginationHandler(object): async def get_messages( self, - requester, - room_id=None, - pagin_config=None, - as_client_event=True, - event_filter=None, - ): + requester: Requester, + room_id: str, + pagin_config: PaginationConfig, + as_client_event: bool = True, + event_filter: Optional[Filter] = None, + ) -> Dict[str, Any]: """Get messages in a room. Args: - requester (Requester): The user requesting messages. - room_id (str): The room they want messages from. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config rules to apply, if any. - as_client_event (bool): True to get events in client-server format. - event_filter (Filter): Filter to apply to results or None + requester: The user requesting messages. + room_id: The room they want messages from. + pagin_config: The pagination config rules to apply, if any. + as_client_event: True to get events in client-server format. + event_filter: Filter to apply to results or None Returns: - dict: Pagination API results + Pagination API results """ user_id = requester.user.to_string() @@ -319,7 +338,7 @@ class PaginationHandler(object): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - await self.hs.get_event_sources().get_current_token_for_pagination() + self.hs.get_event_sources().get_current_token_for_pagination() ) room_token = pagin_config.from_token.room_key @@ -331,7 +350,7 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") - with (await self.pagination_lock.read(room_id)): + with await self.pagination_lock.read(room_id): ( membership, member_event_id, @@ -353,11 +372,15 @@ class PaginationHandler(object): # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. + + # This is only None if the room is world_readable, in which + # case "JOIN" would have been returned. + assert member_event_id + leave_token = await self.store.get_topological_token_for_event( member_event_id ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < max_topo: + if RoomStreamToken.parse(leave_token).topological < max_topo: source_config.from_key = str(leave_token) await self.hs.get_handlers().federation_handler.maybe_backfill( @@ -404,8 +427,8 @@ class PaginationHandler(object): ) if state_ids: - state = await self.store.get_events(list(state_ids.values())) - state = state.values() + state_dict = await self.store.get_events(list(state_ids.values())) + state = state_dict.values() time_now = self.clock.time_msec() diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py index d06b110269..88e2f87200 100644 --- a/synapse/handlers/password_policy.py +++ b/synapse/handlers/password_policy.py @@ -22,7 +22,7 @@ from synapse.api.errors import Codes, PasswordRefusedError logger = logging.getLogger(__name__) -class PasswordPolicyHandler(object): +class PasswordPolicyHandler: def __init__(self, hs): self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3594f3b00f..91a3aec1cc 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -25,24 +25,22 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import Dict, Iterable, List, Set - -from six import iteritems, itervalues +from typing import Dict, Iterable, List, Set, Tuple from prometheus_client import Counter from typing_extensions import ContextManager -from twisted.internet import defer - import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError +from synapse.api.presence import UserPresenceState from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.state import StateHandler +from synapse.storage.databases.main import DataStore +from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure @@ -170,14 +168,14 @@ class BasePresenceHandler(abc.ABC): for user_id in user_ids } - missing = [user_id for user_id, state in iteritems(states) if not state] + missing = [user_id for user_id, state in states.items() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = await self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in iteritems(states) if not state] + missing = [user_id for user_id, state in states.items() if not state] if missing: new = { user_id: UserPresenceState.default(user_id) for user_id in missing @@ -321,7 +319,7 @@ class PresenceHandler(BasePresenceHandler): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.store.db.is_running(): + if not self.store.db_pool.is_running(): return logger.info( @@ -632,7 +630,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) - for prev_state in itervalues(prev_states) + for prev_state in prev_states.values() ] ) self.external_process_last_updated_ms.pop(process_id, None) @@ -775,7 +773,9 @@ class PresenceHandler(BasePresenceHandler): return False - async def get_all_presence_updates(self, last_id, current_id, limit): + async def get_all_presence_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -787,10 +787,31 @@ class PresenceHandler(BasePresenceHandler): - last_user_sync_ts(int) - status_msg(int) - currently_active(int) + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ + # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = await self.store.get_all_presence_updates(last_id, current_id, limit) + rows = await self.store.get_all_presence_updates( + instance_name, last_id, current_id, limit + ) return rows def notify_new_event(self): @@ -874,16 +895,9 @@ class PresenceHandler(BasePresenceHandler): await self._on_user_joined_room(room_id, state_key) - async def _on_user_joined_room(self, room_id, user_id): + async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: """Called when we detect a user joining the room via the current state delta stream. - - Args: - room_id (str) - user_id (str) - - Returns: - Deferred """ if self.is_mine_id(user_id): @@ -914,8 +928,8 @@ class PresenceHandler(BasePresenceHandler): # TODO: Check that this is actually a new server joining the # room. - user_ids = await self.state.get_current_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, user_ids)) + users = await self.state.get_current_users_in_room(room_id) + user_ids = list(filter(self.is_mine_id, users)) states_d = await self.current_state_for_users(user_ids) @@ -996,7 +1010,7 @@ def format_user_presence_state(state, now, include_user_id=True): return content -class PresenceEventSource(object): +class PresenceEventSource: def __init__(self, hs): # We can't call get_presence_handler here because there's a cycle: # @@ -1087,7 +1101,7 @@ class PresenceEventSource(object): return (list(updates.values()), max_token) else: return ( - [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE], + [s for s in updates.values() if s.state != PresenceState.OFFLINE], max_token, ) @@ -1275,22 +1289,24 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): return new_state, persist_and_notify, federation_ping -@defer.inlineCallbacks -def get_interested_parties(store, states): +async def get_interested_parties( + store: DataStore, states: List[UserPresenceState] +) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: """Given a list of states return which entities (rooms, users) are interested in the given states. Args: - states (list(UserPresenceState)) + store + states Returns: - 2-tuple: `(room_ids_to_states, users_to_states)`, + A 2-tuple of `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] users_to_states = {} # type: Dict[str, List[UserPresenceState]] for state in states: - room_ids = yield store.get_rooms_for_user(state.user_id) + room_ids = await store.get_rooms_for_user(state.user_id) for room_id in room_ids: room_ids_to_states.setdefault(room_id, []).append(state) @@ -1300,34 +1316,36 @@ def get_interested_parties(store, states): return room_ids_to_states, users_to_states -@defer.inlineCallbacks -def get_interested_remotes(store, states, state_handler): +async def get_interested_remotes( + store: DataStore, states: List[UserPresenceState], state_handler: StateHandler +) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. All the presence states should be for local users only. Args: - store (DataStore) - states (list(UserPresenceState)) + store + states + state_handler Returns: - Deferred list of ([destinations], [UserPresenceState]), where for - each row the list of UserPresenceState should be sent to each + A list of 2-tuples of destinations and states, where for + each tuple the list of UserPresenceState should be sent to each destination """ - hosts_and_states = [] + hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]] # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote # hosts in those rooms. - room_ids_to_states, users_to_states = yield get_interested_parties(store, states) + room_ids_to_states, users_to_states = await get_interested_parties(store, states) - for room_id, states in iteritems(room_ids_to_states): - hosts = yield state_handler.get_current_hosts_in_room(room_id) + for room_id, states in room_ids_to_states.items(): + hosts = await state_handler.get_current_hosts_in_room(room_id) hosts_and_states.append((hosts, states)) - for user_id, states in iteritems(users_to_states): + for user_id, states in users_to_states.items(): host = get_domain_from_id(user_id) hosts_and_states.append(([host], states)) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 302efc1b9a..0cb8fad89a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -14,10 +14,7 @@ # limitations under the License. import logging - -from six import raise_from - -from twisted.internet import defer +import random from synapse.api.errors import ( AuthError, @@ -56,16 +53,15 @@ class BaseProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() - @defer.inlineCallbacks - def get_profile(self, user_id): + async def get_profile(self, user_id): target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -76,7 +72,7 @@ class BaseProfileHandler(BaseHandler): return {"displayname": displayname, "avatar_url": avatar_url} else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": user_id}, @@ -84,12 +80,11 @@ class BaseProfileHandler(BaseHandler): ) return result except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() - @defer.inlineCallbacks - def get_profile_from_cache(self, user_id): + async def get_profile_from_cache(self, user_id): """Get the profile information from our local cache. If the user is ours then the profile information will always be corect. Otherwise, it may be out of date/missing. @@ -97,10 +92,10 @@ class BaseProfileHandler(BaseHandler): target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -110,14 +105,13 @@ class BaseProfileHandler(BaseHandler): return {"displayname": displayname, "avatar_url": avatar_url} else: - profile = yield self.store.get_from_remote_profile_cache(user_id) + profile = await self.store.get_from_remote_profile_cache(user_id) return profile or {} - @defer.inlineCallbacks - def get_displayname(self, target_user): + async def get_displayname(self, target_user): if self.hs.is_mine(target_user): try: - displayname = yield self.store.get_profile_displayname( + displayname = await self.store.get_profile_displayname( target_user.localpart ) except StoreError as e: @@ -128,14 +122,14 @@ class BaseProfileHandler(BaseHandler): return displayname else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": target_user.to_string(), "field": "displayname"}, ignore_backoff=True, ) except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() @@ -167,6 +161,9 @@ class BaseProfileHandler(BaseHandler): Codes.FORBIDDEN, ) + if not isinstance(new_displayname, str): + raise SynapseError(400, "Invalid displayname") + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -191,11 +188,10 @@ class BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - @defer.inlineCallbacks - def get_avatar_url(self, target_user): + async def get_avatar_url(self, target_user): if self.hs.is_mine(target_user): try: - avatar_url = yield self.store.get_profile_avatar_url( + avatar_url = await self.store.get_profile_avatar_url( target_user.localpart ) except StoreError as e: @@ -205,14 +201,14 @@ class BaseProfileHandler(BaseHandler): return avatar_url else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=target_user.domain, query_type="profile", args={"user_id": target_user.to_string(), "field": "avatar_url"}, ignore_backoff=True, ) except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() @@ -221,8 +217,14 @@ class BaseProfileHandler(BaseHandler): async def set_avatar_url( self, target_user, requester, new_avatar_url, by_admin=False ): - """target_user is the user whose avatar_url is to be changed; - auth_user is the user attempting to make this change.""" + """Set a new avatar URL for a user. + + Args: + target_user (UserID): the user whose avatar URL is to be changed. + requester (Requester): The user attempting to make this change. + new_avatar_url (str): The avatar URL to give this user. + by_admin (bool): Whether this change was made by an administrator. + """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -236,6 +238,9 @@ class BaseProfileHandler(BaseHandler): 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN ) + if not isinstance(new_avatar_url, str): + raise SynapseError(400, "Invalid displayname") + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) @@ -255,8 +260,7 @@ class BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - @defer.inlineCallbacks - def on_profile_query(self, args): + async def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -266,12 +270,12 @@ class BaseProfileHandler(BaseHandler): response = {} try: if just_field is None or just_field == "displayname": - response["displayname"] = yield self.store.get_profile_displayname( + response["displayname"] = await self.store.get_profile_displayname( user.localpart ) if just_field is None or just_field == "avatar_url": - response["avatar_url"] = yield self.store.get_profile_avatar_url( + response["avatar_url"] = await self.store.get_profile_avatar_url( user.localpart ) except StoreError as e: @@ -287,6 +291,12 @@ class BaseProfileHandler(BaseHandler): await self.ratelimit(requester) + # Do not actually update the room state for shadow-banned users. + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + return + room_ids = await self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: @@ -306,8 +316,7 @@ class BaseProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", room_id, str(e) ) - @defer.inlineCallbacks - def check_profile_query_allowed(self, target_user, requester=None): + async def check_profile_query_allowed(self, target_user, requester=None): """Checks whether a profile query is allowed. If the 'require_auth_for_profile_requests' config flag is set to True and a 'requester' is provided, the query is only allowed if the two users @@ -339,8 +348,8 @@ class BaseProfileHandler(BaseHandler): return try: - requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) - target_user_rooms = yield self.store.get_rooms_for_user( + requester_rooms = await self.store.get_rooms_for_user(requester.to_string()) + target_user_rooms = await self.store.get_rooms_for_user( target_user.to_string() ) @@ -373,25 +382,24 @@ class MasterProfileHandler(BaseProfileHandler): "Update remote profile", self._update_remote_profile_cache ) - @defer.inlineCallbacks - def _update_remote_profile_cache(self): + async def _update_remote_profile_cache(self): """Called periodically to check profiles of remote users we haven't checked in a while. """ - entries = yield self.store.get_remote_profile_cache_entries_that_expire( + entries = await self.store.get_remote_profile_cache_entries_that_expire( last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS ) for user_id, displayname, avatar_url in entries: - is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( + is_subscribed = await self.store.is_subscribed_remote_profile_for_user( user_id ) if not is_subscribed: - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) continue try: - profile = yield self.federation.make_query( + profile = await self.federation.make_query( destination=get_domain_from_id(user_id), query_type="profile", args={"user_id": user_id}, @@ -400,7 +408,7 @@ class MasterProfileHandler(BaseProfileHandler): except Exception: logger.exception("Failed to get avatar_url") - yield self.store.update_remote_profile_cache( + await self.store.update_remote_profile_cache( user_id, displayname, avatar_url ) continue @@ -409,4 +417,4 @@ class MasterProfileHandler(BaseProfileHandler): new_avatar = profile.get("avatar_url") # We always hit update to update the last_check timestamp - yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar) + await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 8bc100db42..2cc6c2eb68 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.handlers._base import BaseHandler from synapse.types import ReadReceipt, get_domain_from_id from synapse.util.async_helpers import maybe_awaitable @@ -125,19 +123,18 @@ class ReceiptsHandler(BaseHandler): await self.federation.send_read_receipt(receipt) -class ReceiptEventSource(object): +class ReceiptEventSource: def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) - to_key = yield self.get_current_key() + to_key = self.get_current_key() if from_key == to_key: return [], to_key - events = yield self.store.get_linearized_receipts_for_rooms( + events = await self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) @@ -146,8 +143,7 @@ class ReceiptEventSource(object): def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): + async def get_pagination_rows(self, user, config, key): to_key = int(config.from_key) if config.to_key: @@ -155,8 +151,8 @@ class ReceiptEventSource(object): else: from_key = None - room_ids = yield self.store.get_rooms_for_user(user.to_string()) - events = yield self.store.get_linearized_receipts_for_rooms( + room_ids = await self.store.get_rooms_for_user(user.to_string()) + events = await self.store.get_linearized_receipts_for_rooms( room_ids, from_key=from_key, to_key=to_key ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 51979ea43e..cde2dbca92 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -17,7 +17,7 @@ import logging from synapse import types -from synapse.api.constants import MAX_USERID_LENGTH, LoginType +from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict @@ -26,8 +26,9 @@ from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, ReplicationRegisterServlet, ) -from synapse.types import RoomAlias, RoomID, UserID, create_requester -from synapse.util.async_helpers import Linearizer +from synapse.spam_checker_api import RegistrationBehaviour +from synapse.storage.state import StateFilter +from synapse.types import RoomAlias, UserID, create_requester from ._base import BaseHandler @@ -49,16 +50,11 @@ class RegistrationHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() self.identity_handler = self.hs.get_handlers().identity_handler self.ratelimiter = hs.get_registration_ratelimiter() - - self._next_generated_user_id = None - self.macaroon_gen = hs.get_macaroon_generator() - - self._generate_user_id_linearizer = Linearizer( - name="_generate_user_id_linearizer" - ) self._server_notices_mxid = hs.config.server_notices_mxid + self.spam_checker = hs.get_spam_checker() + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( @@ -131,7 +127,9 @@ class RegistrationHandler(BaseHandler): try: int(localpart) raise SynapseError( - 400, "Numeric user IDs are reserved for guest users." + 400, + "Numeric user IDs are reserved for guest users.", + errcode=Codes.INVALID_USERNAME, ) except ValueError: pass @@ -149,6 +147,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, + user_agent_ips=None, ): """Registers a new client on the server. @@ -166,6 +165,8 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. + user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + during the registration process. Returns: str: user_id Raises: @@ -173,6 +174,24 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) + result = self.spam_checker.check_registration_for_spam( + threepid, localpart, user_agent_ips or [], + ) + + if result == RegistrationBehaviour.DENY: + logger.info( + "Blocked registration of %r", localpart, + ) + # We return a 429 to make it not obvious that they've been + # denied. + raise SynapseError(429, "Rate limited") + + shadow_banned = result == RegistrationBehaviour.SHADOW_BAN + if shadow_banned: + logger.info( + "Shadow banning registration of %r", localpart, + ) + # do not check_auth_blocking if the call is coming through the Admin API if not by_admin: await self.auth.check_auth_blocking(threepid=threepid) @@ -201,6 +220,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) if self.hs.config.user_directory_search_all_users: @@ -218,7 +238,7 @@ class RegistrationHandler(BaseHandler): if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = await self._generate_user_id() + localpart = await self.store.generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) @@ -231,6 +251,7 @@ class RegistrationHandler(BaseHandler): make_guest=make_guest, create_profile_with_displayname=default_display_name, address=address, + shadow_banned=shadow_banned, ) # Successfully registered @@ -270,51 +291,157 @@ class RegistrationHandler(BaseHandler): return user_id - async def _auto_join_rooms(self, user_id): - """Automatically joins users to auto join rooms - creating the room in the first place - if the user is the first to be created. + async def _create_and_join_rooms(self, user_id: str): + """ + Create the auto-join rooms and join or invite the user to them. + + This should only be called when the first "real" user registers. Args: - user_id(str): The user to join + user_id: The user to join """ - # auto-join the user to any rooms we're supposed to dump them into - fake_requester = create_requester(user_id) + # Getting the handlers during init gives a dependency loop. + room_creation_handler = self.hs.get_room_creation_handler() + room_member_handler = self.hs.get_room_member_handler() - # try to create the room if we're the first real user on the server. Note - # that an auto-generated support or bot user is not a real user and will never be - # the user to create the room - should_auto_create_rooms = False - is_real_user = await self.store.is_real_user(user_id) - if self.hs.config.autocreate_auto_join_rooms and is_real_user: - count = await self.store.count_real_users() - should_auto_create_rooms = count == 1 - for r in self.hs.config.auto_join_rooms: + # Generate a stub for how the rooms will be configured. + stub_config = { + "preset": self.hs.config.registration.autocreate_auto_join_room_preset, + } + + # If the configuration providers a user ID to create rooms with, use + # that instead of the first user registered. + requires_join = False + if self.hs.config.registration.auto_join_user_id: + fake_requester = create_requester( + self.hs.config.registration.auto_join_user_id + ) + + # If the room requires an invite, add the user to the list of invites. + if self.hs.config.registration.auto_join_room_requires_invite: + stub_config["invite"] = [user_id] + + # If the room is being created by a different user, the first user + # registered needs to join it. Note that in the case of an invitation + # being necessary this will occur after the invite was sent. + requires_join = True + else: + fake_requester = create_requester(user_id) + + # Choose whether to federate the new room. + if not self.hs.config.registration.autocreate_auto_join_rooms_federated: + stub_config["creation_content"] = {"m.federate": False} + + for r in self.hs.config.registration.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) + try: - if should_auto_create_rooms: - room_alias = RoomAlias.from_string(r) - if self.hs.hostname != room_alias.domain: - logger.warning( - "Cannot create room alias %s, " - "it does not match server domain", - r, - ) - else: - # create room expects the localpart of the room alias - room_alias_localpart = room_alias.localpart - - # getting the RoomCreationHandler during init gives a dependency - # loop - await self.hs.get_room_creation_handler().create_room( - fake_requester, - config={ - "preset": "public_chat", - "room_alias_name": room_alias_localpart, - }, + room_alias = RoomAlias.from_string(r) + + if self.hs.hostname != room_alias.domain: + logger.warning( + "Cannot create room alias %s, " + "it does not match server domain", + r, + ) + else: + # A shallow copy is OK here since the only key that is + # modified is room_alias_name. + config = stub_config.copy() + # create room expects the localpart of the room alias + config["room_alias_name"] = room_alias.localpart + + info, _ = await room_creation_handler.create_room( + fake_requester, config=config, ratelimit=False, + ) + + # If the room does not require an invite, but another user + # created it, then ensure the first user joins it. + if requires_join: + await room_member_handler.update_membership( + requester=create_requester(user_id), + target=UserID.from_string(user_id), + room_id=info["room_id"], + # Since it was just created, there are no remote hosts. + remote_room_hosts=[], + action="join", ratelimit=False, ) + + except ConsentNotGivenError as e: + # Technically not necessary to pull out this error though + # moving away from bare excepts is a good thing to do. + logger.error("Failed to join new user to %r: %r", r, e) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + + async def _join_rooms(self, user_id: str): + """ + Join or invite the user to the auto-join rooms. + + Args: + user_id: The user to join + """ + room_member_handler = self.hs.get_room_member_handler() + + for r in self.hs.config.registration.auto_join_rooms: + logger.info("Auto-joining %s to %s", user_id, r) + + try: + room_alias = RoomAlias.from_string(r) + + if RoomAlias.is_valid(r): + ( + room_id, + remote_room_hosts, + ) = await room_member_handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() else: - await self._join_user_to_room(fake_requester, r) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (r,) + ) + + # Calculate whether the room requires an invite or can be + # joined directly. Note that unless a join rule of public exists, + # it is treated as requiring an invite. + requires_invite = True + + state = await self.store.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) + ) + + event_id = state.get((EventTypes.JoinRules, "")) + if event_id: + join_rules_event = await self.store.get_event( + event_id, allow_none=True + ) + if join_rules_event: + join_rule = join_rules_event.content.get("join_rule", None) + requires_invite = join_rule and join_rule != JoinRules.PUBLIC + + # Send the invite, if necessary. + if requires_invite: + await room_member_handler.update_membership( + requester=create_requester( + self.hs.config.registration.auto_join_user_id + ), + target=UserID.from_string(user_id), + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="invite", + ratelimit=False, + ) + + # Send the join. + await room_member_handler.update_membership( + requester=create_requester(user_id), + target=UserID.from_string(user_id), + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="join", + ratelimit=False, + ) + except ConsentNotGivenError as e: # Technically not necessary to pull out this error though # moving away from bare excepts is a good thing to do. @@ -322,6 +449,29 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) + async def _auto_join_rooms(self, user_id: str): + """Automatically joins users to auto join rooms - creating the room in the first place + if the user is the first to be created. + + Args: + user_id: The user to join + """ + # auto-join the user to any rooms we're supposed to dump them into + + # try to create the room if we're the first real user on the server. Note + # that an auto-generated support or bot user is not a real user and will never be + # the user to create the room + should_auto_create_rooms = False + is_real_user = await self.store.is_real_user(user_id) + if self.hs.config.registration.autocreate_auto_join_rooms and is_real_user: + count = await self.store.count_real_users() + should_auto_create_rooms = count == 1 + + if should_auto_create_rooms: + await self._create_and_join_rooms(user_id) + else: + await self._join_rooms(user_id) + async def post_consent_actions(self, user_id): """A series of registration actions that can only be carried out once consent has been granted @@ -380,42 +530,6 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - async def _generate_user_id(self): - if self._next_generated_user_id is None: - with await self._generate_user_id_linearizer.queue(()): - if self._next_generated_user_id is None: - self._next_generated_user_id = ( - await self.store.find_next_generated_user_id_localpart() - ) - - id = self._next_generated_user_id - self._next_generated_user_id += 1 - return str(id) - - async def _join_user_to_room(self, requester, room_identifier): - room_member_handler = self.hs.get_room_member_handler() - if RoomID.is_valid(room_identifier): - room_id = room_identifier - elif RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = await room_member_handler.lookup_room_alias( - room_alias - ) - room_id = room_id.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) - - await room_member_handler.update_membership( - requester=requester, - target=requester.user, - room_id=room_id, - remote_room_hosts=remote_room_hosts, - action="join", - ratelimit=False, - ) - def check_registration_ratelimit(self, address): """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -443,6 +557,7 @@ class RegistrationHandler(BaseHandler): admin=False, user_type=None, address=None, + shadow_banned=False, ): """Register user in the datastore. @@ -460,9 +575,10 @@ class RegistrationHandler(BaseHandler): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the registration. + shadow_banned (bool): Whether to shadow-ban the user Returns: - Deferred + Awaitable """ if self.hs.config.worker_app: return self._register_client( @@ -475,6 +591,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) else: return self.store.register_user( @@ -486,6 +603,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=create_profile_with_displayname, admin=admin, user_type=user_type, + shadow_banned=shadow_banned, ) async def register_device( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 61db3ccc43..a29305f655 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,19 +20,28 @@ import itertools import logging import math +import random import string from collections import OrderedDict -from typing import Tuple - -from six import iteritems, string_types - -from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset +from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple + +from synapse.api.constants import ( + EventTypes, + JoinRules, + Membership, + RoomCreationPreset, + RoomEncryptionAlgorithms, +) from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError +from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( + JsonDict, + MutableStateMap, Requester, RoomAlias, RoomID, @@ -40,6 +49,7 @@ from synapse.types import ( StateMap, StreamToken, UserID, + create_requester, ) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer @@ -48,6 +58,9 @@ from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) id_server_scheme = "https://" @@ -56,32 +69,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000 class RoomCreationHandler(BaseHandler): - - PRESETS_DICT = { - RoomCreationPreset.PRIVATE_CHAT: { - "join_rules": JoinRules.INVITE, - "history_visibility": "shared", - "original_invitees_have_ops": False, - "guest_can_join": True, - "power_level_content_override": {"invite": 0}, - }, - RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { - "join_rules": JoinRules.INVITE, - "history_visibility": "shared", - "original_invitees_have_ops": True, - "guest_can_join": True, - "power_level_content_override": {"invite": 0}, - }, - RoomCreationPreset.PUBLIC_CHAT: { - "join_rules": JoinRules.PUBLIC, - "history_visibility": "shared", - "original_invitees_have_ops": False, - "guest_can_join": False, - "power_level_content_override": {}, - }, - } - - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(RoomCreationHandler, self).__init__(hs) self.spam_checker = hs.get_spam_checker() @@ -89,6 +77,39 @@ class RoomCreationHandler(BaseHandler): self.room_member_handler = hs.get_room_member_handler() self.config = hs.config + # Room state based off defined presets + self._presets_dict = { + RoomCreationPreset.PRIVATE_CHAT: { + "join_rules": JoinRules.INVITE, + "history_visibility": "shared", + "original_invitees_have_ops": False, + "guest_can_join": True, + "power_level_content_override": {"invite": 0}, + }, + RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { + "join_rules": JoinRules.INVITE, + "history_visibility": "shared", + "original_invitees_have_ops": True, + "guest_can_join": True, + "power_level_content_override": {"invite": 0}, + }, + RoomCreationPreset.PUBLIC_CHAT: { + "join_rules": JoinRules.PUBLIC, + "history_visibility": "shared", + "original_invitees_have_ops": False, + "guest_can_join": False, + "power_level_content_override": {}, + }, + } # type: Dict[str, Dict[str, Any]] + + # Modify presets to selectively enable encryption by default per homeserver config + for preset_name, preset_config in self._presets_dict.items(): + encrypted = ( + preset_name + in self.config.encryption_enabled_by_default_for_room_presets + ) + preset_config["encrypted"] = encrypted + self._replication = hs.get_replication_data_handler() # linearizer to stop two upgrades happening at once @@ -106,7 +127,7 @@ class RoomCreationHandler(BaseHandler): async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion - ): + ) -> str: """Replace a room with a new room with a different version Args: @@ -115,7 +136,10 @@ class RoomCreationHandler(BaseHandler): new_version: the new room version to use Returns: - Deferred[unicode]: the new room id + the new room id + + Raises: + ShadowBanError if the requester is shadow-banned. """ await self.ratelimit(requester) @@ -151,6 +175,15 @@ class RoomCreationHandler(BaseHandler): async def _upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ): + """ + Args: + requester: the user requesting the upgrade + old_room_id: the id of the room to be replaced + new_versions: the version to upgrade the room to + + Raises: + ShadowBanError if the requester is shadow-banned. + """ user_id = requester.user.to_string() # start by allocating a new room id @@ -202,6 +235,9 @@ class RoomCreationHandler(BaseHandler): old_room_state = await tombstone_context.get_current_state_ids() + # We know the tombstone event isn't an outlier so it has current state. + assert old_room_state is not None + # update any aliases await self._move_aliases_to_new_room( requester, old_room_id, new_room_id, old_room_state @@ -226,7 +262,7 @@ class RoomCreationHandler(BaseHandler): old_room_id: str, new_room_id: str, old_room_state: StateMap[str], - ): + ) -> None: """Send updated power levels in both rooms after an upgrade Args: @@ -235,8 +271,8 @@ class RoomCreationHandler(BaseHandler): new_room_id: the id of the replacement room old_room_state: the state map for the old room - Returns: - Deferred + Raises: + ShadowBanError if the requester is shadow-banned. """ old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, "")) @@ -309,7 +345,7 @@ class RoomCreationHandler(BaseHandler): new_room_id: str, new_room_version: RoomVersion, tombstone_event_id: str, - ): + ) -> None: """Populate a new room based on an old room Args: @@ -319,8 +355,6 @@ class RoomCreationHandler(BaseHandler): created with _gemerate_room_id()) new_room_version: the new room version to use tombstone_event_id: the ID of the tombstone event in the old room. - Returns: - Deferred """ user_id = requester.user.to_string() @@ -364,7 +398,7 @@ class RoomCreationHandler(BaseHandler): # map from event_id to BaseEvent old_room_state_events = await self.store.get_events(old_room_state_ids.values()) - for k, old_event_id in iteritems(old_room_state_ids): + for k, old_event_id in old_room_state_ids.items(): old_event = old_room_state_events.get(old_event_id) if old_event: initial_state[k] = old_event.content @@ -417,7 +451,7 @@ class RoomCreationHandler(BaseHandler): old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) - for k, old_event in iteritems(old_room_member_state_events): + for old_event in old_room_member_state_events.values(): # Only transfer ban events if ( "membership" in old_event.content @@ -520,17 +554,21 @@ class RoomCreationHandler(BaseHandler): logger.error("Unable to send updated alias events in new room: %s", e) async def create_room( - self, requester, config, ratelimit=True, creator_join_profile=None + self, + requester: Requester, + config: JsonDict, + ratelimit: bool = True, + creator_join_profile: Optional[JsonDict] = None, ) -> Tuple[dict, int]: """ Creates a new room. Args: - requester (synapse.types.Requester): + requester: The user who requested the room creation. - config (dict) : A dict of configuration options. - ratelimit (bool): set to False to disable the rate limiter + config : A dict of configuration options. + ratelimit: set to False to disable the rate limiter - creator_join_profile (dict|None): + creator_join_profile: Set to override the displayname and avatar for the creating user in this room. If unset, displayname and avatar will be derived from the user's profile. If set, should contain the @@ -582,7 +620,7 @@ class RoomCreationHandler(BaseHandler): "room_version", self.config.default_room_version.identifier ) - if not isinstance(room_version_id, string_types): + if not isinstance(room_version_id, str): raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) @@ -593,6 +631,7 @@ class RoomCreationHandler(BaseHandler): Codes.UNSUPPORTED_ROOM_VERSION, ) + room_alias = None if "room_alias_name" in config: for wchar in string.whitespace: if wchar in config["room_alias_name"]: @@ -603,9 +642,8 @@ class RoomCreationHandler(BaseHandler): if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) - else: - room_alias = None + invite_3pid_list = config.get("invite_3pid", []) invite_list = config.get("invite", []) for i in invite_list: try: @@ -614,6 +652,14 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) + if (invite_list or invite_3pid_list) and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + + # Allow the request to go through, but remove any associated invites. + invite_3pid_list = [] + invite_list = [] + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") @@ -628,8 +674,6 @@ class RoomCreationHandler(BaseHandler): % (user_id,), ) - invite_3pid_list = config.get("invite_3pid", []) - visibility = config.get("visibility", None) is_public = visibility == "public" @@ -724,6 +768,8 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct + # Note that update_membership with an action of "invite" can raise a + # ShadowBanError, but this was handled above by emptying invite_list. _, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), @@ -738,6 +784,8 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] + # Note that do_3pid_invite can raise a ShadowBanError, but this was + # handled above by emptying invite_3pid_list. last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, @@ -763,23 +811,30 @@ class RoomCreationHandler(BaseHandler): async def _send_events_for_new_room( self, - creator, # A Requester object. - room_id, - preset_config, - invite_list, - initial_state, - creation_content, - room_alias=None, - power_level_content_override=None, # Doesn't apply when initial state has power level state event content - creator_join_profile=None, + creator: Requester, + room_id: str, + preset_config: str, + invite_list: List[str], + initial_state: MutableStateMap, + creation_content: JsonDict, + room_alias: Optional[RoomAlias] = None, + power_level_content_override: Optional[JsonDict] = None, + creator_join_profile: Optional[JsonDict] = None, ) -> int: """Sends the initial events into a new room. + `power_level_content_override` doesn't apply when initial state has + power level state event content. + Returns: The stream_id of the last event persisted. """ - def create(etype, content, **kwargs): + creator_id = creator.user.to_string() + + event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} + + def create(etype: str, content: JsonDict, **kwargs) -> JsonDict: e = {"type": etype, "content": content} e.update(event_keys) @@ -787,22 +842,20 @@ class RoomCreationHandler(BaseHandler): return e - async def send(etype, content, **kwargs) -> int: + async def send(etype: str, content: JsonDict, **kwargs) -> int: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) + # Allow these events to be sent even if the user is shadow-banned to + # allow the room creation to complete. ( _, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( - creator, event, ratelimit=False + creator, event, ratelimit=False, ignore_shadow_ban=True, ) return last_stream_id - config = RoomCreationHandler.PRESETS_DICT[preset_config] - - creator_id = creator.user.to_string() - - event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} + config = self._presets_dict[preset_config] creation_content.update({"creator": creator_id}) await send(etype=EventTypes.Create, content=creation_content) @@ -844,7 +897,7 @@ class RoomCreationHandler(BaseHandler): "kick": 50, "redact": 50, "invite": 50, - } + } # type: JsonDict if config["original_invitees_have_ops"]: for invitee in invite_list: @@ -888,10 +941,17 @@ class RoomCreationHandler(BaseHandler): etype=etype, state_key=state_key, content=content ) + if config["encrypted"]: + last_sent_stream_id = await send( + etype=EventTypes.RoomEncryption, + state_key="", + content={"algorithm": RoomEncryptionAlgorithms.DEFAULT}, + ) + return last_sent_stream_id async def _generate_room_id( - self, creator_id: str, is_public: str, room_version: RoomVersion, + self, creator_id: str, is_public: bool, room_version: RoomVersion, ): # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -914,24 +974,31 @@ class RoomCreationHandler(BaseHandler): raise StoreError(500, "Couldn't generate a room ID.") -class RoomContextHandler(object): - def __init__(self, hs): +class RoomContextHandler: + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_store = self.storage.state - async def get_event_context(self, user, room_id, event_id, limit, event_filter): + async def get_event_context( + self, + user: UserID, + room_id: str, + event_id: str, + limit: int, + event_filter: Optional[Filter], + ) -> Optional[JsonDict]: """Retrieves events, pagination tokens and state around a given event in a room. Args: - user (UserID) - room_id (str) - event_id (str) - limit (int): The maximum number of events to return in total + user + room_id + event_id + limit: The maximum number of events to return in total (excluding state). - event_filter (Filter|None): the filter to apply to the events returned + event_filter: the filter to apply to the events returned (excluding the target event_id) Returns: @@ -1017,16 +1084,22 @@ class RoomContextHandler(object): return results -class RoomEventSource(object): - def __init__(self, hs): +class RoomEventSource: + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() async def get_new_events( - self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None - ): + self, + user: UserID, + from_key: str, + limit: int, + room_ids: List[str], + is_guest: bool, + explicit_room_id: Optional[str] = None, + ) -> Tuple[List[EventBase], str]: # We just ignore the key for now. - to_key = await self.get_current_key() + to_key = self.get_current_key() from_token = RoomStreamToken.parse(from_key) if from_token.topological: @@ -1066,8 +1139,208 @@ class RoomEventSource(object): return (events, end_key) - def get_current_key(self): - return self.store.get_room_events_max_id() + def get_current_key(self) -> str: + return "s%d" % (self.store.get_room_max_stream_ordering(),) - def get_current_key_for_room(self, room_id): + def get_current_key_for_room(self, room_id: str) -> Awaitable[str]: return self.store.get_room_events_max_id(room_id) + + +class RoomShutdownHandler: + + DEFAULT_MESSAGE = ( + "Sharing illegal content on this server is not permitted and rooms in" + " violation will be blocked." + ) + DEFAULT_ROOM_NAME = "Content Violation Notification" + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.room_member_handler = hs.get_room_member_handler() + self._room_creation_handler = hs.get_room_creation_handler() + self._replication = hs.get_replication_data_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.state = hs.get_state_handler() + self.store = hs.get_datastore() + + async def shutdown_room( + self, + room_id: str, + requester_user_id: str, + new_room_user_id: Optional[str] = None, + new_room_name: Optional[str] = None, + message: Optional[str] = None, + block: bool = False, + ) -> dict: + """ + Shuts down a room. Moves all local users and room aliases automatically + to a new room if `new_room_user_id` is set. Otherwise local users only + leave the room without any information. + + The new room will be created with the user specified by the + `new_room_user_id` parameter as room administrator and will contain a + message explaining what happened. Users invited to the new room will + have power level `-10` by default, and thus be unable to speak. + + The local server will only have the power to move local user and room + aliases to the new room. Users on other servers will be unaffected. + + Args: + room_id: The ID of the room to shut down. + requester_user_id: + User who requested the action and put the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + + Returns: a dict containing the following keys: + kicked_users: An array of users (`user_id`) that were kicked. + failed_to_kick_users: + An array of users (`user_id`) that that were not kicked. + local_aliases: + An array of strings representing the local aliases that were + migrated from the old room to the new. + new_room_id: A string representing the room ID of the new room. + """ + + if not new_room_name: + new_room_name = self.DEFAULT_ROOM_NAME + if not message: + message = self.DEFAULT_MESSAGE + + if not RoomID.is_valid(room_id): + raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + + if not await self.store.get_room(room_id): + raise NotFoundError("Unknown room id %s" % (room_id,)) + + # This will work even if the room is already blocked, but that is + # desirable in case the first attempt at blocking the room failed below. + if block: + await self.store.block_room(room_id, requester_user_id) + + if new_room_user_id is not None: + if not self.hs.is_mine_id(new_room_user_id): + raise SynapseError( + 400, "User must be our own: %s" % (new_room_user_id,) + ) + + room_creator_requester = create_requester(new_room_user_id) + + info, stream_id = await self._room_creation_handler.create_room( + room_creator_requester, + config={ + "preset": RoomCreationPreset.PUBLIC_CHAT, + "name": new_room_name, + "power_level_content_override": {"users_default": -10}, + }, + ratelimit=False, + ) + new_room_id = info["room_id"] + + logger.info( + "Shutting down room %r, joining to new room: %r", room_id, new_room_id + ) + + # We now wait for the create room to come back in via replication so + # that we can assume that all the joins/invites have propogated before + # we try and auto join below. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + else: + new_room_id = None + logger.info("Shutting down room %r", room_id) + + users = await self.state.get_current_users_in_room(room_id) + kicked_users = [] + failed_to_kick_users = [] + for user_id in users: + if not self.hs.is_mine_id(user_id): + continue + + logger.info("Kicking %r from %r...", user_id, room_id) + + try: + # Kick users from room + target_requester = create_requester(user_id) + _, stream_id = await self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=room_id, + action=Membership.LEAVE, + content={}, + ratelimit=False, + require_consent=False, + ) + + # Wait for leave to come in over replication before trying to forget. + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + + await self.room_member_handler.forget(target_requester.user, room_id) + + # Join users to new room + if new_room_user_id: + await self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=new_room_id, + action=Membership.JOIN, + content={}, + ratelimit=False, + require_consent=False, + ) + + kicked_users.append(user_id) + except Exception: + logger.exception( + "Failed to leave old room and join new room for %r", user_id + ) + failed_to_kick_users.append(user_id) + + # Send message in new room and move aliases + if new_room_user_id: + await self.event_creation_handler.create_and_send_nonmember_event( + room_creator_requester, + { + "type": "m.room.message", + "content": {"body": message, "msgtype": "m.text"}, + "room_id": new_room_id, + "sender": new_room_user_id, + }, + ratelimit=False, + ) + + aliases_for_room = await self.store.get_aliases_for_room(room_id) + + await self.store.update_aliases_for_room( + room_id, new_room_id, requester_user_id + ) + else: + aliases_for_room = [] + + return { + "kicked_users": kicked_users, + "failed_to_kick_users": failed_to_kick_users, + "local_aliases": aliases_for_room, + "new_room_id": new_room_id, + } diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 4cbc02b0d0..5dd7b28391 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -17,17 +17,13 @@ import logging from collections import namedtuple from typing import Any, Dict, Optional -from six import iteritems - import msgpack from unpaddedbase64 import decode_base64, encode_base64 -from twisted.internet import defer - from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, HttpResponseException from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.util.caches.response_cache import ResponseCache from ._base import BaseHandler @@ -49,7 +45,7 @@ class RoomListHandler(BaseHandler): hs, "remote_room_list", timeout_ms=30 * 1000 ) - def get_local_public_room_list( + async def get_local_public_room_list( self, limit=None, since_token=None, @@ -74,7 +70,7 @@ class RoomListHandler(BaseHandler): API """ if not self.enable_room_list_search: - return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) + return {"chunk": [], "total_room_count_estimate": 0} logger.info( "Getting public room list: limit=%r, since=%r, search=%r, network=%r", @@ -89,7 +85,7 @@ class RoomListHandler(BaseHandler): # appservice specific lists. logger.info("Bypassing cache as search request.") - return self._get_public_room_list( + return await self._get_public_room_list( limit, since_token, search_filter, @@ -98,7 +94,7 @@ class RoomListHandler(BaseHandler): ) key = (limit, since_token, network_tuple) - return self.response_cache.wrap( + return await self.response_cache.wrap( key, self._get_public_room_list, limit, @@ -107,8 +103,7 @@ class RoomListHandler(BaseHandler): from_federation=from_federation, ) - @defer.inlineCallbacks - def _get_public_room_list( + async def _get_public_room_list( self, limit: Optional[int] = None, since_token: Optional[str] = None, @@ -147,7 +142,7 @@ class RoomListHandler(BaseHandler): # we request one more than wanted to see if there are more pages to come probing_limit = limit + 1 if limit is not None else None - results = yield self.store.get_largest_public_rooms( + results = await self.store.get_largest_public_rooms( network_tuple, search_filter, probing_limit, @@ -223,44 +218,44 @@ class RoomListHandler(BaseHandler): response["chunk"] = results - response["total_room_count_estimate"] = yield self.store.count_public_rooms( + response["total_room_count_estimate"] = await self.store.count_public_rooms( network_tuple, ignore_non_federatable=from_federation ) return response - @cachedInlineCallbacks(num_args=1, cache_context=True) - def generate_room_entry( + @cached(num_args=1, cache_context=True) + async def generate_room_entry( self, - room_id, - num_joined_users, + room_id: str, + num_joined_users: int, cache_context, - with_alias=True, - allow_private=False, - ): + with_alias: bool = True, + allow_private: bool = False, + ) -> Optional[dict]: """Returns the entry for a room Args: - room_id (str): The room's ID. - num_joined_users (int): Number of users in the room. + room_id: The room's ID. + num_joined_users: Number of users in the room. cache_context: Information for cached responses. - with_alias (bool): Whether to return the room's aliases in the result. - allow_private (bool): Whether invite-only rooms should be shown. + with_alias: Whether to return the room's aliases in the result. + allow_private: Whether invite-only rooms should be shown. Returns: - Deferred[dict|None]: Returns a room entry as a dictionary, or None if this + Returns a room entry as a dictionary, or None if this room was determined not to be shown publicly. """ result = {"room_id": room_id, "num_joined_members": num_joined_users} if with_alias: - aliases = yield self.store.get_aliases_for_room( + aliases = await self.store.get_aliases_for_room( room_id, on_invalidate=cache_context.invalidate ) if aliases: result["aliases"] = aliases - current_state_ids = yield self.store.get_current_state_ids( + current_state_ids = await self.store.get_current_state_ids( room_id, on_invalidate=cache_context.invalidate ) @@ -268,10 +263,10 @@ class RoomListHandler(BaseHandler): # We're not in the room, so may as well bail out here. return result - event_map = yield self.store.get_events( + event_map = await self.store.get_events( [ event_id - for key, event_id in iteritems(current_state_ids) + for key, event_id in current_state_ids.items() if key[0] in ( EventTypes.Create, @@ -338,8 +333,7 @@ class RoomListHandler(BaseHandler): return result - @defer.inlineCallbacks - def get_remote_public_room_list( + async def get_remote_public_room_list( self, server_name, limit=None, @@ -358,7 +352,7 @@ class RoomListHandler(BaseHandler): # to a locally-filtered search if we must. try: - res = yield self._get_remote_list_cached( + res = await self._get_remote_list_cached( server_name, limit=limit, since_token=since_token, @@ -383,7 +377,7 @@ class RoomListHandler(BaseHandler): limit = None since_token = None - res = yield self._get_remote_list_cached( + res = await self._get_remote_list_cached( server_name, limit=limit, since_token=since_token, @@ -402,7 +396,7 @@ class RoomListHandler(BaseHandler): return res - def _get_remote_list_cached( + async def _get_remote_list_cached( self, server_name, limit=None, @@ -414,7 +408,7 @@ class RoomListHandler(BaseHandler): repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search - return repl_layer.get_public_rooms( + return await repl_layer.get_public_rooms( server_name, limit=limit, since_token=since_token, @@ -430,7 +424,7 @@ class RoomListHandler(BaseHandler): include_all_networks, third_party_instance_id, ) - return self.remote_response_cache.wrap( + return await self.remote_response_cache.wrap( key, repl_layer.get_public_rooms, server_name, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0f7af982f0..32b7e323fa 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2016-2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,28 +15,43 @@ import abc import logging -from typing import Dict, Iterable, List, Optional, Tuple +import random +from http import HTTPStatus +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union -from six.moves import http_client +from unpaddedbase64 import encode_base64 from synapse import types -from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.constants import MAX_DEPTH, EventTypes, Membership +from synapse.api.errors import ( + AuthError, + Codes, + LimitExceededError, + ShadowBanError, + SynapseError, +) +from synapse.api.ratelimiting import Ratelimiter +from synapse.api.room_versions import EventFormatVersions +from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase +from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext -from synapse.replication.http.membership import ( - ReplicationLocallyRejectInviteRestServlet, -) -from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID +from synapse.events.validator import EventValidator +from synapse.storage.roommember import RoomsForUser +from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) -class RoomMemberHandler(object): +class RoomMemberHandler: # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level # API that takes ID strings and returns pagination chunks. These concerns @@ -46,7 +59,7 @@ class RoomMemberHandler(object): __metaclass__ = abc.ABCMeta - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -75,10 +88,17 @@ class RoomMemberHandler(object): ) if self._is_on_event_persistence_instance: self.persist_event_storage = hs.get_storage().persistence - else: - self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client( - hs - ) + + self._join_rate_limiter_local = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, + burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, + ) + self._join_rate_limiter_remote = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, + burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, + ) # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as @@ -106,46 +126,28 @@ class RoomMemberHandler(object): raise NotImplementedError() @abc.abstractmethod - async def _remote_reject_invite( + async def remote_reject_invite( self, + invite_event_id: str, + txn_id: Optional[str], requester: Requester, - remote_room_hosts: List[str], - room_id: str, - target: UserID, - content: dict, - ) -> Tuple[Optional[str], int]: - """Attempt to reject an invite for a room this server is not in. If we - fail to do so we locally mark the invite as rejected. + content: JsonDict, + ) -> Tuple[str, int]: + """ + Rejects an out-of-band invite we have received from a remote server Args: - requester - remote_room_hosts: List of servers to use to try and reject invite - room_id - target: The user rejecting the invite - content: The content for the rejection event + invite_event_id: ID of the invite to be rejected + txn_id: optional transaction ID supplied by the client + requester: user making the rejection request, according to the access token + content: additional content to include in the rejection event. + Normally an empty dict. Returns: - A dictionary to be returned to the client, may - include event_id etc, or nothing if we locally rejected + event id, stream_id of the leave event """ raise NotImplementedError() - async def locally_reject_invite(self, user_id: str, room_id: str) -> int: - """Mark the invite has having been rejected even though we failed to - create a leave event for it. - """ - if self._is_on_event_persistence_instance: - return await self.persist_event_storage.locally_reject_invite( - user_id, room_id - ) - else: - result = await self._locally_reject_client( - instance_name=self._event_stream_writer_instance, - user_id=user_id, - room_id=room_id, - ) - return result["stream_id"] - @abc.abstractmethod async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has joined the @@ -174,7 +176,7 @@ class RoomMemberHandler(object): target: UserID, room_id: str, membership: str, - prev_event_ids: Collection[str], + prev_event_ids: List[str], txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, @@ -215,24 +217,40 @@ class RoomMemberHandler(object): _, stream_id = await self.store.get_event_ordering(duplicate.event_id) return duplicate.event_id, stream_id - stream_id = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit - ) - prev_state_ids = await context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + newly_joined = False if event.membership == Membership.JOIN: - # Only fire user_joined_room if the user has actually joined the - # room. Don't bother if the user is just changing their profile - # info. newly_joined = True if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN + + # Only rate-limit if the user actually joined the room, otherwise we'll end + # up blocking profile updates. if newly_joined: - await self._user_joined_room(target, room_id) + time_now_s = self.clock.time() + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_local.can_requester_do_action(requester) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + + stream_id = await self.event_creation_handler.handle_new_client_event( + requester, event, context, extra_users=[target], ratelimit=ratelimit, + ) + + if event.membership == Membership.JOIN and newly_joined: + # Only fire user_joined_room if the user has actually joined the + # room. Don't bother if the user is just changing their profile + # info. + await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) @@ -289,7 +307,32 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Tuple[Optional[str], int]: + ) -> Tuple[str, int]: + """Update a user's membership in a room. + + Params: + requester: The user who is performing the update. + target: The user whose membership is being updated. + room_id: The room ID whose membership is being updated. + action: The membership change, see synapse.api.constants.Membership. + txn_id: The transaction ID, if given. + remote_room_hosts: Remote servers to send the update to. + third_party_signed: Information from a 3PID invite. + ratelimit: Whether to rate limit the request. + content: The content of the created event. + require_consent: Whether consent is required. + + Returns: + A tuple of the new event ID and stream ID. + + Raises: + ShadowBanError if a shadow-banned requester attempts to send an invite. + """ + if action == Membership.INVITE and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -320,7 +363,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Tuple[Optional[str], int]: + ) -> Tuple[str, int]: content_specified = bool(content) if content is None: content = {} @@ -329,7 +372,7 @@ class RoomMemberHandler(object): # later on. content = dict(content) - if not self.allow_per_room_profiles: + if not self.allow_per_room_profiles or requester.shadow_banned: # Strip profile data, knowing that new profile data will be added to the # event's content in event_creation_handler.create_event() using the target's # global profile. @@ -361,7 +404,7 @@ class RoomMemberHandler(object): if effective_membership_state == Membership.INVITE: # block any attempts to invite the server notices mxid if target.to_string() == self._server_notices_mxid: - raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -444,7 +487,7 @@ class RoomMemberHandler(object): is_blocked = await self._is_server_notice_room(room_id) if is_blocked: raise SynapseError( - http_client.FORBIDDEN, + HTTPStatus.FORBIDDEN, "You cannot reject this invite", errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM, ) @@ -463,6 +506,17 @@ class RoomMemberHandler(object): raise AuthError(403, "Guest access not allowed") if not is_host_in_room: + time_now_s = self.clock.time() + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_remote.can_requester_do_action(requester,) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) @@ -486,24 +540,43 @@ class RoomMemberHandler(object): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - inviter = await self._get_inviter(target.to_string(), room_id) - if not inviter: + invite = await self.store.get_invite_for_local_user_in_room( + user_id=target.to_string(), room_id=room_id + ) # type: Optional[RoomsForUser] + if not invite: + logger.info( + "%s sent a leave request to %s, but that is not an active room " + "on this server, and there is no pending invite", + target, + room_id, + ) + raise SynapseError(404, "Not a known room") - if self.hs.is_mine(inviter): - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath. - # - # This is a bit of a hack, because the room might still be - # active on other servers. - pass - else: - # send the rejection to the inviter's HS. - remote_room_hosts = remote_room_hosts + [inviter.domain] - return await self._remote_reject_invite( - requester, remote_room_hosts, room_id, target, content, + logger.info( + "%s rejects invite to %s from %s", target, room_id, invite.sender + ) + + if not self.hs.is_mine_id(invite.sender): + # send the rejection to the inviter's HS (with fallback to + # local event) + return await self.remote_reject_invite( + invite.event_id, txn_id, requester, content, ) + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath, which will also send the + # rejection out to any other servers we believe are still in the room. + + # thanks to overzealous cleaning up of event_forward_extremities in + # `delete_old_current_state_events`, it's possible to end up with no + # forward extremities here. If that happens, let's just hang the + # rejection off the invite event. + # + # see: https://github.com/matrix-org/synapse/issues/7139 + if len(latest_event_ids) == 0: + latest_event_ids = [invite.event_id] + return await self._local_membership_update( requester=requester, target=target, @@ -669,9 +742,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join( - self, current_state_ids: Dict[Tuple[str, str], str] - ) -> bool: + async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool: """ Returns whether a guest can join a room based on its current state. """ @@ -681,7 +752,7 @@ class RoomMemberHandler(object): guest_access = await self.store.get_event(guest_access_id) - return ( + return bool( guest_access and guest_access.content and "guest_access" in guest_access.content @@ -738,6 +809,25 @@ class RoomMemberHandler(object): txn_id: Optional[str], id_access_token: Optional[str] = None, ) -> int: + """Invite a 3PID to a room. + + Args: + room_id: The room to invite the 3PID to. + inviter: The user sending the invite. + medium: The 3PID's medium. + address: The 3PID's address. + id_server: The identity server to use. + requester: The user making the request. + txn_id: The transaction ID this is part of, or None if this is not + part of a transaction. + id_access_token: The optional identity server access token. + + Returns: + The new stream ID. + + Raises: + ShadowBanError if the requester has been shadow-banned. + """ if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -745,6 +835,11 @@ class RoomMemberHandler(object): 403, "Invites have been disabled on this server", Codes.FORBIDDEN ) + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. await self.base_handler.ratelimit(requester) @@ -769,6 +864,8 @@ class RoomMemberHandler(object): ) if invitee: + # Note that update_membership with an action of "invite" can raise + # a ShadowBanError, but this was done above already. _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) @@ -874,9 +971,7 @@ class RoomMemberHandler(object): ) return stream_id - async def _is_host_in_room( - self, current_state_ids: Dict[Tuple[str, str], str] - ) -> bool: + async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool: # Have we just created the room, and is this about to be the very # first member event? create_event_id = current_state_ids.get(("m.room.create", "")) @@ -967,7 +1062,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - if self.hs.config.limit_remote_rooms.enabled: + check_complexity = self.hs.config.limit_remote_rooms.enabled + if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: + check_complexity = not await self.auth.is_server_admin(user) + + if check_complexity: # Fetch the room complexity too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts @@ -990,7 +1089,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. - if self.hs.config.limit_remote_rooms.enabled: + if check_complexity: if too_complex is False: # We checked, and we're under the limit. return event_id, stream_id @@ -1003,7 +1102,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return event_id, stream_id # The room is too large. Leave. - requester = types.create_requester(user, None, False, None) + requester = types.create_requester(user, None, False, False, None) await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) @@ -1015,33 +1114,119 @@ class RoomMemberMasterHandler(RoomMemberHandler): return event_id, stream_id - async def _remote_reject_invite( + async def remote_reject_invite( self, + invite_event_id: str, + txn_id: Optional[str], requester: Requester, - remote_room_hosts: List[str], - room_id: str, - target: UserID, - content: dict, - ) -> Tuple[Optional[str], int]: - """Implements RoomMemberHandler._remote_reject_invite + content: JsonDict, + ) -> Tuple[str, int]: """ + Rejects an out-of-band invite received from a remote user + + Implements RoomMemberHandler.remote_reject_invite + """ + invite_event = await self.store.get_event(invite_event_id) + room_id = invite_event.room_id + target_user = invite_event.state_key + + # first of all, try doing a rejection via the inviting server fed_handler = self.federation_handler try: + inviter_id = UserID.from_string(invite_event.sender) event, stream_id = await fed_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, target.to_string(), content=content, + [inviter_id.domain], room_id, target_user, content=content ) return event.event_id, stream_id except Exception as e: - # if we were unable to reject the exception, just mark - # it as rejected on our end and plough ahead. + # if we were unable to reject the invite, we will generate our own + # leave event. # # The 'except' clause is very broad, but we need to # capture everything from DNS failures upwards # logger.warning("Failed to reject invite: %s", e) - stream_id = await self.locally_reject_invite(target.to_string(), room_id) - return None, stream_id + return await self._locally_reject_invite( + invite_event, txn_id, requester, content + ) + + async def _locally_reject_invite( + self, + invite_event: EventBase, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: + """Generate a local invite rejection + + This is called after we fail to reject an invite via a remote server. It + generates an out-of-band membership event locally. + + Args: + invite_event: the invite to be rejected + txn_id: optional transaction ID supplied by the client + requester: user making the rejection request, according to the access token + content: additional content to include in the rejection event. + Normally an empty dict. + """ + + room_id = invite_event.room_id + target_user = invite_event.state_key + room_version = await self.store.get_room_version(room_id) + + content["membership"] = Membership.LEAVE + + # the auth events for the new event are the same as that of the invite, plus + # the invite itself. + # + # the prev_events are just the invite. + invite_hash = invite_event.event_id # type: Union[str, Tuple] + if room_version.event_format == EventFormatVersions.V1: + alg, h = compute_event_reference_hash(invite_event) + invite_hash = (invite_event.event_id, {alg: encode_base64(h)}) + + auth_events = tuple(invite_event.auth_events) + (invite_hash,) + prev_events = (invite_hash,) + + # we cap depth of generated events, to ensure that they are not + # rejected by other servers (and so that they can be persisted in + # the db) + depth = min(invite_event.depth + 1, MAX_DEPTH) + + event_dict = { + "depth": depth, + "auth_events": auth_events, + "prev_events": prev_events, + "type": EventTypes.Member, + "room_id": room_id, + "sender": target_user, + "content": content, + "state_key": target_user, + } + + event = create_local_event_from_event_dict( + clock=self.clock, + hostname=self.hs.hostname, + signing_key=self.hs.signing_key, + room_version=room_version, + event_dict=event_dict, + ) + event.internal_metadata.outlier = True + event.internal_metadata.out_of_band_membership = True + if txn_id is not None: + event.internal_metadata.txn_id = txn_id + if requester.access_token_id is not None: + event.internal_metadata.token_id = requester.access_token_id + + EventValidator().validate_new(event, self.config) + + context = await self.state_handler.compute_event_context(event) + context.app_service = requester.app_service + stream_id = await self.event_creation_handler.handle_new_client_event( + requester, event, context, extra_users=[UserID.from_string(target_user)], + ) + return event.event_id, stream_id async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 02e0c4103d..897338fd54 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -61,21 +61,22 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret["event_id"], ret["stream_id"] - async def _remote_reject_invite( + async def remote_reject_invite( self, + invite_event_id: str, + txn_id: Optional[str], requester: Requester, - remote_room_hosts: List[str], - room_id: str, - target: UserID, content: dict, - ) -> Tuple[Optional[str], int]: - """Implements RoomMemberHandler._remote_reject_invite + ) -> Tuple[str, int]: + """ + Rejects an out-of-band invite received from a remote user + + Implements RoomMemberHandler.remote_reject_invite """ ret = await self._remote_reject_client( + invite_event_id=invite_event_id, + txn_id=txn_id, requester=requester, - remote_room_hosts=remote_room_hosts, - room_id=room_id, - user_id=target.to_string(), content=content, ) return ret["event_id"], ret["stream_id"] diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index abecaa8313..66b063f991 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -14,15 +14,16 @@ # limitations under the License. import logging import re -from typing import Callable, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple import attr import saml2 import saml2.response from saml2.client import Saml2Client -from synapse.api.errors import SynapseError +from synapse.api.errors import AuthError, SynapseError from synapse.config import ConfigError +from synapse.config.saml2_config import SamlAttributeRequirement from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -34,6 +35,9 @@ from synapse.types import ( from synapse.util.async_helpers import Linearizer from synapse.util.iterutils import chunk_seq +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) @@ -49,7 +53,8 @@ class Saml2SessionData: class SamlHandler: - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): + self.hs = hs self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -62,6 +67,7 @@ class SamlHandler: self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) + self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements # plugin to do custom mapping from saml response to mxid self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( @@ -73,7 +79,7 @@ class SamlHandler: self._auth_provider_id = "saml" # a map from saml session id to Saml2SessionData object - self._outstanding_requests_dict = {} + self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] # a lock on the mappings self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) @@ -96,6 +102,9 @@ class SamlHandler: relay_state=client_redirect_url ) + # Since SAML sessions timeout it is useful to log when they were created. + logger.info("Initiating a new SAML session: %s" % (reqid,)) + now = self._clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( creation_time=now, ui_auth_session_id=ui_auth_session_id, @@ -125,8 +134,14 @@ class SamlHandler: # the dict. self.expire_sessions() + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state + resp_bytes, relay_state, user_agent, ip_address ) # Complete the interactive auth session or the login. @@ -139,7 +154,11 @@ class SamlHandler: await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user( - self, resp_bytes: str, client_redirect_url: str + self, + resp_bytes: str, + client_redirect_url: str, + user_agent: str, + ip_address: str, ) -> Tuple[str, Optional[Saml2SessionData]]: """ Given a sample response, retrieve the cached session and user for it. @@ -147,6 +166,8 @@ class SamlHandler: Args: resp_bytes: The SAML response. client_redirect_url: The redirect URL passed in by the client. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Returns: Tuple of the user ID and SAML session associated with this response. @@ -162,11 +183,18 @@ class SamlHandler: saml2.BINDING_HTTP_POST, outstanding=self._outstanding_requests_dict, ) + except saml2.response.UnsolicitedResponse as e: + # the pysaml2 library helpfully logs an ERROR here, but neglects to log + # the session ID. I don't really want to put the full text of the exception + # in the (user-visible) exception message, so let's log the exception here + # so we can track down the session IDs later. + logger.warning(str(e)) + raise SynapseError(400, "Unexpected SAML2 login.") except Exception as e: - raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,)) + raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,)) if saml2_auth.not_signed: - raise SynapseError(400, "SAML2 response was not signed") + raise SynapseError(400, "SAML2 response was not signed.") logger.debug("SAML2 response: %s", saml2_auth.origxml) for assertion in saml2_auth.assertions: @@ -185,6 +213,9 @@ class SamlHandler: saml2_auth.in_response_to, None ) + for requirement in self._saml2_attribute_requirements: + _check_attribute_requirement(saml2_auth.ava, requirement) + remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url ) @@ -273,6 +304,7 @@ class SamlHandler: localpart=localpart, default_display_name=displayname, bind_emails=emails, + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( @@ -291,6 +323,21 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] +def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement): + values = ava.get(req.attribute, []) + for v in values: + if v == req.value: + return + + logger.info( + "SAML2 attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + raise AuthError(403, "You are not authorized to log in here.") + + DOT_REPLACE_PATTERN = re.compile( ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) @@ -313,12 +360,12 @@ MXID_MAPPER_MAP = { @attr.s -class SamlConfig(object): +class SamlConfig: mxid_source_attribute = attr.ib() mxid_mapper = attr.ib() -class DefaultSamlMappingProvider(object): +class DefaultSamlMappingProvider: __version__ = "0.0.1" def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi): diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 4d40d3ac9c..d58f9788c5 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -15,6 +15,7 @@ import itertools import logging +from typing import Iterable from unpaddedbase64 import decode_base64, encode_base64 @@ -37,7 +38,7 @@ class SearchHandler(BaseHandler): self.state_store = self.storage.state self.auth = hs.get_auth() - async def get_old_rooms_from_upgraded_room(self, room_id): + async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: """Retrieves room IDs of old rooms in the history of an upgraded room. We do so by checking the m.room.create event of the room for a @@ -48,10 +49,10 @@ class SearchHandler(BaseHandler): The full list of all found rooms in then returned. Args: - room_id (str): id of the room to search through. + room_id: id of the room to search through. Returns: - Deferred[iterable[str]]: predecessor room ids + Predecessor room ids """ historical_room_ids = [] @@ -339,7 +340,7 @@ class SearchHandler(BaseHandler): # If client has asked for "context" for each event (i.e. some surrounding # events and state), fetch that if event_context is not None: - now_token = await self.hs.get_event_sources().get_current_token() + now_token = self.hs.get_event_sources().get_current_token() contexts = {} for event in allowed_events: diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index 8590c1eff4..7a4ae0727a 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -18,7 +18,7 @@ import logging logger = logging.getLogger(__name__) -class StateDeltasHandler(object): +class StateDeltasHandler: def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 149f861239..249ffe2a55 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -232,7 +232,7 @@ class StatsHandler: if membership == prev_membership: pass # noop - if membership == Membership.JOIN: + elif membership == Membership.JOIN: room_stats_delta["joined_members"] += 1 elif membership == Membership.INVITE: room_stats_delta["invited_members"] += 1 diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6bdb24baff..e2ddb628ff 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -16,9 +16,7 @@ import itertools import logging -from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple - -from six import iteritems, itervalues +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -33,6 +31,7 @@ from synapse.storage.state import StateFilter from synapse.types import ( Collection, JsonDict, + MutableStateMap, RoomStreamToken, StateMap, StreamToken, @@ -45,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache from synapse.util.metrics import Measure, measure_func from synapse.visibility import filter_events_for_client +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Debug logger for https://github.com/matrix-org/synapse/issues/4422 @@ -96,7 +98,12 @@ class TimelineBatch: __bool__ = __nonzero__ # python3 -@attr.s(slots=True, frozen=True) +# We can't freeze this class, because we need to update it after it's instantiated to +# update its unread count. This is because we calculate the unread count for a room only +# if there are updates for it, which we check after the instance has been created. +# This should not be a big deal because we update the notification counts afterwards as +# well anyway. +@attr.s(slots=True) class JoinedSyncResult: room_id = attr.ib(type=str) timeline = attr.ib(type=TimelineBatch) @@ -105,6 +112,7 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) + unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -238,8 +246,8 @@ class SyncResult: __bool__ = __nonzero__ # python3 -class SyncHandler(object): - def __init__(self, hs): +class SyncHandler: + def __init__(self, hs: "HomeServer"): self.hs_config = hs.config self.store = hs.get_datastore() self.notifier = hs.get_notifier() @@ -285,6 +293,7 @@ class SyncHandler(object): timeout, full_state, ) + logger.debug("Returning sync response for %s", user_id) return res async def _wait_for_sync_for_user( @@ -390,7 +399,7 @@ class SyncHandler(object): # result returned by the event source is poor form (it might cache # the object) room_id = event["room_id"] - event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} + event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) receipt_key = since_token.receipt_key if since_token else "0" @@ -408,7 +417,7 @@ class SyncHandler(object): for event in receipts: room_id = event["room_id"] # exclude room id, as above - event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} + event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) return now_token, ephemeral_by_room @@ -422,10 +431,6 @@ class SyncHandler(object): potential_recents: Optional[List[EventBase]] = None, newly_joined_room: bool = False, ) -> TimelineBatch: - """ - Returns: - a Deferred TimelineBatch - """ with Measure(self.clock, "load_filtered_recents"): timeline_limit = sync_config.filter_collection.timeline_limit() block_all_timeline = ( @@ -454,7 +459,7 @@ class SyncHandler(object): current_state_ids_map = await self.state.get_current_state_ids( room_id ) - current_state_ids = frozenset(itervalues(current_state_ids_map)) + current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( self.storage, @@ -509,7 +514,7 @@ class SyncHandler(object): current_state_ids_map = await self.state.get_current_state_ids( room_id ) - current_state_ids = frozenset(itervalues(current_state_ids_map)) + current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( self.storage, @@ -593,7 +598,7 @@ class SyncHandler(object): room_id: str, sync_config: SyncConfig, batch: TimelineBatch, - state: StateMap[EventBase], + state: MutableStateMap[EventBase], now_token: StreamToken, ) -> Optional[JsonDict]: """ Works out a room summary block for this room, summarising the number @@ -715,9 +720,8 @@ class SyncHandler(object): ] missing_hero_state = await self.store.get_events(missing_hero_event_ids) - missing_hero_state = missing_hero_state.values() - for s in missing_hero_state: + for s in missing_hero_state.values(): cache.set(s.state_key, s.event_id) state[(EventTypes.Member, s.state_key)] = s @@ -741,7 +745,7 @@ class SyncHandler(object): since_token: Optional[StreamToken], now_token: StreamToken, full_state: bool, - ) -> StateMap[EventBase]: + ) -> MutableStateMap[EventBase]: """ Works out the difference in state between the start of the timeline and the previous sync. @@ -909,7 +913,7 @@ class SyncHandler(object): logger.debug("filtering state from %r...", state_ids) state_ids = { t: event_id - for t, event_id in iteritems(state_ids) + for t, event_id in state_ids.items() if cache.get(t[1]) != event_id } logger.debug("...to %r", state_ids) @@ -935,7 +939,7 @@ class SyncHandler(object): async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig - ) -> Optional[Dict[str, str]]: + ) -> Dict[str, int]: with Measure(self.clock, "unread_notifs_for_room_id"): last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), @@ -943,15 +947,10 @@ class SyncHandler(object): receipt_type="m.read", ) - if last_unread_event_id: - notifs = await self.store.get_unread_event_push_actions_by_room_for_user( - room_id, sync_config.user.to_string(), last_unread_event_id - ) - return notifs - - # There is no new information in this period, so your notification - # count is whatever it was last time. - return None + notifs = await self.store.get_unread_event_push_actions_by_room_for_user( + room_id, sync_config.user.to_string(), last_unread_event_id + ) + return notifs async def generate_sync_result( self, @@ -965,7 +964,7 @@ class SyncHandler(object): # this is due to some of the underlying streams not supporting the ability # to query up to a given point. # Always use the `now_token` in `SyncResultBuilder` - now_token = await self.event_sources.get_current_token() + now_token = self.event_sources.get_current_token() logger.debug( "Calculating sync response for %r between %s and %s", @@ -992,10 +991,14 @@ class SyncHandler(object): joined_room_ids=joined_room_ids, ) + logger.debug("Fetching account data") + account_data_by_room = await self._generate_sync_entry_for_account_data( sync_result_builder ) + logger.debug("Fetching room data") + res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) @@ -1006,10 +1009,12 @@ class SyncHandler(object): since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: + logger.debug("Fetching presence data") await self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ) + logger.debug("Fetching to-device data") await self._generate_sync_entry_for_to_device(sync_result_builder) device_lists = await self._generate_sync_entry_for_device_list( @@ -1020,6 +1025,7 @@ class SyncHandler(object): newly_left_users=newly_left_users, ) + logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_key_counts = {} # type: JsonDict if device_id: @@ -1027,6 +1033,7 @@ class SyncHandler(object): user_id, device_id ) + logger.debug("Fetching group data") await self._generate_sync_entry_for_groups(sync_result_builder) # debug for https://github.com/matrix-org/synapse/issues/4422 @@ -1037,6 +1044,7 @@ class SyncHandler(object): "Sync result for newly joined room %s: %r", room_id, joined_room ) + logger.debug("Sync response calculation complete") return SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, @@ -1409,8 +1417,9 @@ class SyncHandler(object): newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms - def handle_room_entries(room_entry): - return self._generate_room_entry( + async def handle_room_entries(room_entry): + logger.debug("Generating room entry for %s", room_entry.room_id) + res = await self._generate_room_entry( sync_result_builder, ignored_users, room_entry, @@ -1419,6 +1428,8 @@ class SyncHandler(object): account_data=account_data_by_room.get(room_entry.room_id, {}), always_include=sync_result_builder.full_state, ) + logger.debug("Generated room entry for %s", room_entry.room_id) + return res await concurrently_execute(handle_room_entries, room_entries, 10) @@ -1430,7 +1441,7 @@ class SyncHandler(object): if since_token: for joined_sync in sync_result_builder.joined: it = itertools.chain( - joined_sync.timeline.events, itervalues(joined_sync.state) + joined_sync.timeline.events, joined_sync.state.values() ) for event in it: if event.type == EventTypes.Member: @@ -1505,7 +1516,7 @@ class SyncHandler(object): newly_left_rooms = [] room_entries = [] invited = [] - for room_id, events in iteritems(mem_change_events_by_room_id): + for room_id, events in mem_change_events_by_room_id.items(): logger.debug( "Membership changes in %s: [%s]", room_id, @@ -1762,7 +1773,7 @@ class SyncHandler(object): ignored_users: Set[str], room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], - tags: Optional[List[JsonDict]], + tags: Optional[Dict[str, Dict[str, Any]]], account_data: Dict[str, JsonDict], always_include: bool = False, ): @@ -1878,7 +1889,7 @@ class SyncHandler(object): ) if room_builder.rtype == "joined": - unread_notifications = {} # type: Dict[str, str] + unread_notifications = {} # type: Dict[str, int] room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1887,14 +1898,16 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, + unread_count=0, ) if room_sync or always_include: notifs = await self.unread_notifs_for_room_id(room_id, sync_config) - if notifs is not None: - unread_notifications["notification_count"] = notifs["notify_count"] - unread_notifications["highlight_count"] = notifs["highlight_count"] + unread_notifications["notification_count"] = notifs["notify_count"] + unread_notifications["highlight_count"] = notifs["highlight_count"] + + room_sync.unread_count = notifs["unread_count"] sync_result_builder.joined.append(room_sync) @@ -1993,17 +2006,17 @@ def _calculate_state( event_id_to_key = { e: key for key, e in itertools.chain( - iteritems(timeline_contains), - iteritems(previous), - iteritems(timeline_start), - iteritems(current), + timeline_contains.items(), + previous.items(), + timeline_start.items(), + current.items(), ) } - c_ids = set(itervalues(current)) - ts_ids = set(itervalues(timeline_start)) - p_ids = set(itervalues(previous)) - tc_ids = set(itervalues(timeline_contains)) + c_ids = set(current.values()) + ts_ids = set(timeline_start.values()) + p_ids = set(previous.values()) + tc_ids = set(timeline_contains.values()) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, @@ -2017,7 +2030,7 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member + e for t, e in timeline_start.items() if t[0] == EventTypes.Member ) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids @@ -2062,7 +2075,7 @@ class SyncResultBuilder: @attr.s -class RoomSyncResultBuilder(object): +class RoomSyncResultBuilder: """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c7bc14c623..3cbfc2d780 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -14,18 +14,21 @@ # limitations under the License. import logging +import random from collections import namedtuple -from typing import List +from typing import TYPE_CHECKING, List, Set, Tuple -from twisted.internet import defer - -from synapse.api.errors import AuthError, SynapseError -from synapse.logging.context import run_in_background +from synapse.api.errors import AuthError, ShadowBanError, SynapseError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.streams import TypingStream from synapse.types import UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -41,48 +44,48 @@ FEDERATION_TIMEOUT = 60 * 1000 FEDERATION_PING_INTERVAL = 40 * 1000 -class TypingHandler(object): - def __init__(self, hs): +class FollowerTypingHandler: + """A typing handler on a different process than the writer that is updated + via replication. + """ + + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.server_name = hs.config.server_name - self.auth = hs.get_auth() - self.is_mine_id = hs.is_mine_id - self.notifier = hs.get_notifier() - self.state = hs.get_state_handler() - - self.hs = hs - self.clock = hs.get_clock() - self.wheel_timer = WheelTimer(bucket_size=5000) + self.is_mine_id = hs.is_mine_id - self.federation = hs.get_federation_sender() + self.federation = None + if hs.should_send_federation(): + self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + if hs.config.worker.writers.typing != hs.get_instance_name(): + hs.get_federation_registry().register_instance_for_edu( + "m.typing", hs.config.worker.writers.typing, + ) - hs.get_distributor().observe("user_left_room", self.user_left_room) + # map room IDs to serial numbers + self._room_serials = {} + # map room IDs to sets of users currently typing + self._room_typing = {} - self._member_typing_until = {} # clock time we expect to stop self._member_last_federation_poke = {} - + self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 - self._reset() - - # caches which room_ids changed at which serials - self._typing_stream_change_cache = StreamChangeCache( - "TypingStreamChangeCache", self._latest_room_serial - ) self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self): - """ - Reset the typing handler's data caches. + """Reset the typing handler's data caches. """ # map room IDs to serial numbers self._room_serials = {} # map room IDs to sets of users currently typing self._room_typing = {} + self._member_last_federation_poke = {} + self.wheel_timer = WheelTimer(bucket_size=5000) + def _handle_timeouts(self): logger.debug("Checking for typing timeouts") @@ -91,34 +94,143 @@ class TypingHandler(object): members = set(self.wheel_timer.fetch(now)) for member in members: - if not self.is_typing(member): - # Nothing to do if they're no longer typing - continue - - until = self._member_typing_until.get(member, None) - if not until or until <= now: - logger.info("Timing out typing for: %s", member.user_id) - self._stopped_typing(member) - continue - - # Check if we need to resend a keep alive over federation for this - # user. - if self.hs.is_mine_id(member.user_id): - last_fed_poke = self._member_last_federation_poke.get(member, None) - if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_in_background(self._push_remote, member=member, typing=True) - - # Add a paranoia timer to ensure that we always have a timer for - # each person typing. - self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) + self._handle_timeout_for_member(now, member) + + def _handle_timeout_for_member(self, now: int, member: RoomMember): + if not self.is_typing(member): + # Nothing to do if they're no longer typing + return + + # Check if we need to resend a keep alive over federation for this + # user. + if self.federation and self.is_mine_id(member.user_id): + last_fed_poke = self._member_last_federation_poke.get(member, None) + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: + run_as_background_process( + "typing._push_remote", self._push_remote, member=member, typing=True + ) + + # Add a paranoia timer to ensure that we always have a timer for + # each person typing. + self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) - @defer.inlineCallbacks - def started_typing(self, target_user, auth_user, room_id, timeout): + async def _push_remote(self, member, typing): + if not self.federation: + return + + try: + users = await self.store.get_users_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() + + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, obj=member, then=now + FEDERATION_PING_INTERVAL + ) + + for domain in {get_domain_from_id(u) for u in users}: + if domain != self.server_name: + logger.debug("sending typing update to %s", domain) + self.federation.build_and_send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": member.room_id, + "user_id": member.user_id, + "typing": typing, + }, + key=member, + ) + except Exception: + logger.exception("Error pushing typing notif to remotes") + + def process_replication_rows( + self, token: int, rows: List[TypingStream.TypingStreamRow] + ): + """Should be called whenever we receive updates for typing stream. + """ + + if self._latest_room_serial > token: + # The master has gone backwards. To prevent inconsistent data, just + # clear everything. + self._reset() + + # Set the latest serial token to whatever the server gave us. + self._latest_room_serial = token + + for row in rows: + self._room_serials[row.room_id] = token + + prev_typing = set(self._room_typing.get(row.room_id, [])) + now_typing = set(row.user_ids) + self._room_typing[row.room_id] = row.user_ids + + run_as_background_process( + "_handle_change_in_typing", + self._handle_change_in_typing, + row.room_id, + prev_typing, + now_typing, + ) + + async def _handle_change_in_typing( + self, room_id: str, prev_typing: Set[str], now_typing: Set[str] + ): + """Process a change in typing of a room from replication, sending EDUs + for any local users. + """ + for user_id in now_typing - prev_typing: + if self.is_mine_id(user_id): + await self._push_remote(RoomMember(room_id, user_id), True) + + for user_id in prev_typing - now_typing: + if self.is_mine_id(user_id): + await self._push_remote(RoomMember(room_id, user_id), False) + + def get_current_token(self): + return self._latest_room_serial + + +class TypingWriterHandler(FollowerTypingHandler): + def __init__(self, hs): + super().__init__(hs) + + assert hs.config.worker.writers.typing == hs.get_instance_name() + + self.auth = hs.get_auth() + self.notifier = hs.get_notifier() + + self.hs = hs + + hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + + hs.get_distributor().observe("user_left_room", self.user_left_room) + + self._member_typing_until = {} # clock time we expect to stop + + # caches which room_ids changed at which serials + self._typing_stream_change_cache = StreamChangeCache( + "TypingStreamChangeCache", self._latest_room_serial + ) + + def _handle_timeout_for_member(self, now: int, member: RoomMember): + super()._handle_timeout_for_member(now, member) + + if not self.is_typing(member): + # Nothing to do if they're no longer typing + return + + until = self._member_typing_until.get(member, None) + if not until or until <= now: + logger.info("Timing out typing for: %s", member.user_id) + self._stopped_typing(member) + return + + async def started_typing(self, target_user, requester, room_id, timeout): target_user_id = target_user.to_string() - auth_user_id = auth_user.to_string() + auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") @@ -126,7 +238,12 @@ class TypingHandler(object): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_user_in_room(room_id, target_user_id) + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -145,10 +262,9 @@ class TypingHandler(object): self._push_update(member=member, typing=True) - @defer.inlineCallbacks - def stopped_typing(self, target_user, auth_user, room_id): + async def stopped_typing(self, target_user, requester, room_id): target_user_id = target_user.to_string() - auth_user_id = auth_user.to_string() + auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") @@ -156,7 +272,12 @@ class TypingHandler(object): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_user_in_room(room_id, target_user_id) + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has stopped typing in %s", target_user_id, room_id) @@ -164,12 +285,11 @@ class TypingHandler(object): self._stopped_typing(member) - @defer.inlineCallbacks def user_left_room(self, user, room_id): user_id = user.to_string() if self.is_mine_id(user_id): member = RoomMember(room_id=room_id, user_id=user_id) - yield self._stopped_typing(member) + self._stopped_typing(member) def _stopped_typing(self, member): if member.user_id not in self._room_typing.get(member.room_id, set()): @@ -184,39 +304,13 @@ class TypingHandler(object): def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - run_in_background(self._push_remote, member, typing) - - self._push_update_local(member=member, typing=typing) - - @defer.inlineCallbacks - def _push_remote(self, member, typing): - try: - users = yield self.state.get_current_users_in_room(member.room_id) - self._member_last_federation_poke[member] = self.clock.time_msec() - - now = self.clock.time_msec() - self.wheel_timer.insert( - now=now, obj=member, then=now + FEDERATION_PING_INTERVAL + run_as_background_process( + "typing._push_remote", self._push_remote, member, typing ) - for domain in {get_domain_from_id(u) for u in users}: - if domain != self.server_name: - logger.debug("sending typing update to %s", domain) - self.federation.build_and_send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": member.room_id, - "user_id": member.user_id, - "typing": typing, - }, - key=member, - ) - except Exception: - logger.exception("Error pushing typing notif to remotes") + self._push_update_local(member=member, typing=typing) - @defer.inlineCallbacks - def _recv_edu(self, origin, content): + async def _recv_edu(self, origin, content): room_id = content["room_id"] user_id = content["user_id"] @@ -231,7 +325,7 @@ class TypingHandler(object): ) return - users = yield self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) domains = {get_domain_from_id(u) for u in users} if self.server_name in domains: @@ -259,14 +353,31 @@ class TypingHandler(object): ) async def get_all_typing_updates( - self, last_id: int, current_id: int, limit: int - ) -> List[dict]: - """Get up to `limit` typing updates between the given tokens, earliest - updates first. + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for typing replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: - return [] + return [], current_id, False changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( last_id @@ -280,15 +391,28 @@ class TypingHandler(object): serial = self._room_serials[room_id] if last_id < serial <= current_id: typing = self._room_typing[room_id] - rows.append((serial, room_id, list(typing))) + rows.append((serial, [room_id, list(typing)])) rows.sort() - return rows[:limit] - def get_current_token(self): - return self._latest_room_serial + limited = False + # We, unusually, use a strict limit here as we have all the rows in + # memory rather than pulling them out of the database with a `LIMIT ?` + # clause. + if len(rows) > limit: + rows = rows[:limit] + current_id = rows[-1][0] + limited = True + + return rows, current_id, limited + + def process_replication_rows( + self, token: int, rows: List[TypingStream.TypingStreamRow] + ): + # The writing process should never get updates from replication. + raise Exception("Typing writer instance got typing info over replication") -class TypingNotificationEventSource(object): +class TypingNotificationEventSource: def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() @@ -306,7 +430,7 @@ class TypingNotificationEventSource(object): "content": {"user_ids": list(typing)}, } - def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events(self, from_key, room_ids, **kwargs): with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() @@ -320,7 +444,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - return defer.succeed((events, handler._latest_room_serial)) + return (events, handler._latest_room_serial) def get_current_key(self): return self.get_typing_handler()._latest_room_serial diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 8b24a73319..9146dc1a3b 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -12,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -from canonicaljson import json +import logging +from typing import Any -from twisted.internet import defer from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType from synapse.api.errors import Codes, LoginError, SynapseError from synapse.config.emailconfig import ThreepidBehaviour +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -32,25 +32,25 @@ class UserInteractiveAuthChecker: def __init__(self, hs): pass - def is_enabled(self): + def is_enabled(self) -> bool: """Check if the configuration of the homeserver allows this checker to work Returns: - bool: True if this login type is enabled. + True if this login type is enabled. """ - def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: """Given the authentication dict from the client, attempt to check this step Args: - authdict (dict): authentication dictionary from the client - clientip (str): The IP address of the client. + authdict: authentication dictionary from the client + clientip: The IP address of the client. Raises: SynapseError if authentication failed Returns: - Deferred: the result of authentication (to pass back to the client?) + The result of authentication (to pass back to the client?) """ raise NotImplementedError() @@ -61,8 +61,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - def check_auth(self, authdict, clientip): - return defer.succeed(True) + async def check_auth(self, authdict, clientip): + return True class TermsAuthChecker(UserInteractiveAuthChecker): @@ -71,8 +71,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - def check_auth(self, authdict, clientip): - return defer.succeed(True) + async def check_auth(self, authdict, clientip): + return True class RecaptchaAuthChecker(UserInteractiveAuthChecker): @@ -88,8 +88,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return self._enabled - @defer.inlineCallbacks - def check_auth(self, authdict, clientip): + async def check_auth(self, authdict, clientip): try: user_response = authdict["response"] except KeyError: @@ -106,7 +105,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): # TODO: get this from the homeserver rather than creating a new one for # each request try: - resp_body = yield self._http_client.post_urlencoded_get_json( + resp_body = await self._http_client.post_urlencoded_get_json( self._url, args={ "secret": self._secret, @@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): except PartialDownloadError as pde: # Twisted is silly data = pde.response - resp_body = json.loads(data) + resp_body = json_decoder.decode(data.decode("utf-8")) if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly @@ -218,8 +217,8 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec ThreepidBehaviour.LOCAL, ) - def check_auth(self, authdict, clientip): - return defer.ensureDeferred(self._check_threepid("email", authdict)) + async def check_auth(self, authdict, clientip): + return await self._check_threepid("email", authdict) class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): @@ -232,8 +231,8 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): def is_enabled(self): return bool(self.hs.config.account_threepid_delegate_msisdn) - def check_auth(self, authdict, clientip): - return defer.ensureDeferred(self._check_threepid("msisdn", authdict)) + async def check_auth(self, authdict, clientip): + return await self._check_threepid("msisdn", authdict) INTERACTIVE_AUTH_CHECKERS = [ diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 12423b909a..e21f8dbc58 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -15,8 +15,6 @@ import logging -from six import iteritems, iterkeys - import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.handlers.state_deltas import StateDeltasHandler @@ -236,7 +234,7 @@ class UserDirectoryHandler(StateDeltasHandler): async def _handle_room_publicity_change( self, room_id, prev_event_id, event_id, typ ): - """Handle a room having potentially changed from/to world_readable/publically + """Handle a room having potentially changed from/to world_readable/publicly joinable. Args: @@ -289,7 +287,7 @@ class UserDirectoryHandler(StateDeltasHandler): users_with_profile = await self.state.get_current_users_in_room(room_id) # Remove every user from the sharing tables for that room. - for user_id in iterkeys(users_with_profile): + for user_id in users_with_profile.keys(): await self.store.remove_user_who_share_room(user_id, room_id) # Then, re-add them to the tables. @@ -298,7 +296,7 @@ class UserDirectoryHandler(StateDeltasHandler): # which when ran over an entire room, will result in the same values # being added multiple times. The batching upserts shouldn't make this # too bad, though. - for user_id, profile in iteritems(users_with_profile): + for user_id, profile in users_with_profile.items(): await self._handle_new_user(room_id, user_id, profile) async def _handle_new_user(self, room_id, user_id, profile): @@ -390,9 +388,15 @@ class UserDirectoryHandler(StateDeltasHandler): prev_name = prev_event.content.get("displayname") new_name = event.content.get("displayname") + # If the new name is an unexpected form, do not update the directory. + if not isinstance(new_name, str): + new_name = prev_name prev_avatar = prev_event.content.get("avatar_url") new_avatar = event.content.get("avatar_url") + # If the new avatar is an unexpected form, do not update the directory. + if not isinstance(new_avatar, str): + new_avatar = prev_avatar if prev_name != new_name or prev_avatar != new_avatar: await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 096619a8c2..479746c9c5 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -13,13 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET +from synapse.http.server import DirectServeJsonResource -from synapse.http.server import wrap_json_request_handler - -class AdditionalResource(Resource): +class AdditionalResource(DirectServeJsonResource): """Resource wrapper for additional_resources If the user has configured additional_resources, we need to wrap the @@ -41,16 +38,10 @@ class AdditionalResource(Resource): handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): function to be called to handle the request. """ - Resource.__init__(self) + super().__init__() self._handler = handler - # required by the request_handler wrapper - self.clock = hs.get_clock() - - def render(self, request): - self._async_render(request) - return NOT_DONE_YET - - @wrap_json_request_handler def _async_render(self, request): + # Cheekily pass the result straight through, so we don't need to worry + # if its an awaitable or not. return self._handler(request) diff --git a/synapse/http/client.py b/synapse/http/client.py index 3cef747a4d..13fcab3378 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -15,13 +15,11 @@ # limitations under the License. import logging +import urllib from io import BytesIO -from six import raise_from, text_type -from six.moves import urllib - import treq -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from netaddr import IPAddress from prometheus_client import Counter from zope.interface import implementer, provider @@ -33,6 +31,7 @@ from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IResolutionReceiver, ) +from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone from twisted.web.client import Agent, HTTPConnectionPool, readBody @@ -48,6 +47,7 @@ from synapse.http import ( from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags +from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred logger = logging.getLogger(__name__) @@ -71,7 +71,22 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): return False -class IPBlacklistingResolver(object): +_EPSILON = 0.00000001 + + +def _make_scheduler(reactor): + """Makes a schedular suitable for a Cooperator using the given reactor. + + (This is effectively just a copy from `twisted.internet.task`) + """ + + def _scheduler(x): + return reactor.callLater(_EPSILON, x) + + return _scheduler + + +class IPBlacklistingResolver: """ A proxy for reactor.nameResolver which only produces non-blacklisted IP addresses, preventing DNS rebinding attacks on URL preview. @@ -118,7 +133,7 @@ class IPBlacklistingResolver(object): r.resolutionComplete() @provider(IResolutionReceiver) - class EndpointReceiver(object): + class EndpointReceiver: @staticmethod def resolutionBegan(resolutionInProgress): pass @@ -177,7 +192,7 @@ class BlacklistingAgentWrapper(Agent): ) -class SimpleHttpClient(object): +class SimpleHttpClient: """ A simple, no-frills HTTP client with methods that wrap up common ways of using HTTP in Matrix @@ -214,6 +229,10 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) + # We use this for our body producers to ensure that they use the correct + # reactor. + self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor())) + self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: @@ -225,7 +244,7 @@ class SimpleHttpClient(object): ) @implementer(IReactorPluggableNameResolver) - class Reactor(object): + class Reactor: def __getattr__(_self, attr): if attr == "nameResolver": return nameResolver @@ -266,8 +285,7 @@ class SimpleHttpClient(object): ip_blacklist=self._ip_blacklist, ) - @defer.inlineCallbacks - def request(self, method, uri, data=None, headers=None): + async def request(self, method, uri, data=None, headers=None): """ Args: method (str): HTTP method to use. @@ -280,7 +298,7 @@ class SimpleHttpClient(object): outgoing_requests_counter.labels(method).inc() # log request but strip `access_token` (AS requests for example include this) - logger.info("Sending request %s %s", method, redact_uri(uri)) + logger.debug("Sending request %s %s", method, redact_uri(uri)) with start_active_span( "outgoing-client-request", @@ -294,7 +312,9 @@ class SimpleHttpClient(object): try: body_producer = None if data is not None: - body_producer = QuieterFileBodyProducer(BytesIO(data)) + body_producer = QuieterFileBodyProducer( + BytesIO(data), cooperator=self._cooperator, + ) request_deferred = treq.request( method, @@ -310,7 +330,7 @@ class SimpleHttpClient(object): self.hs.get_reactor(), cancelled_to_request_timed_out_error, ) - response = yield make_deferred_yieldable(request_deferred) + response = await make_deferred_yieldable(request_deferred) incoming_responses_counter.labels(method, response.code).inc() logger.info( @@ -333,8 +353,7 @@ class SimpleHttpClient(object): set_tag("error_reason", e.args[0]) raise - @defer.inlineCallbacks - def post_urlencoded_get_json(self, uri, args={}, headers=None): + async def post_urlencoded_get_json(self, uri, args={}, headers=None): """ Args: uri (str): @@ -343,7 +362,7 @@ class SimpleHttpClient(object): header name to a list of values for that header Returns: - Deferred[object]: parsed json + object: parsed json Raises: HttpResponseException: On a non-2xx HTTP response. @@ -366,19 +385,20 @@ class SimpleHttpClient(object): if headers: actual_headers.update(headers) - response = yield self.request( + response = await self.request( "POST", uri, headers=Headers(actual_headers), data=query_bytes ) - body = yield make_deferred_yieldable(readBody(response)) + body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body) + return json_decoder.decode(body.decode("utf-8")) else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) - @defer.inlineCallbacks - def post_json_get_json(self, uri, post_json, headers=None): + async def post_json_get_json(self, uri, post_json, headers=None): """ Args: @@ -388,7 +408,7 @@ class SimpleHttpClient(object): header name to a list of values for that header Returns: - Deferred[object]: parsed json + object: parsed json Raises: HttpResponseException: On a non-2xx HTTP response. @@ -407,19 +427,20 @@ class SimpleHttpClient(object): if headers: actual_headers.update(headers) - response = yield self.request( + response = await self.request( "POST", uri, headers=Headers(actual_headers), data=json_str ) - body = yield make_deferred_yieldable(readBody(response)) + body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body) + return json_decoder.decode(body.decode("utf-8")) else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) - @defer.inlineCallbacks - def get_json(self, uri, args={}, headers=None): + async def get_json(self, uri, args={}, headers=None): """ Gets some json from the given URI. Args: @@ -431,7 +452,7 @@ class SimpleHttpClient(object): headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from header name to a list of values for that header Returns: - Deferred: Succeeds when we get *any* 2xx HTTP response, with the + Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. Raises: HttpResponseException On a non-2xx HTTP response. @@ -442,11 +463,10 @@ class SimpleHttpClient(object): if headers: actual_headers.update(headers) - body = yield self.get_raw(uri, args, headers=headers) - return json.loads(body) + body = await self.get_raw(uri, args, headers=headers) + return json_decoder.decode(body.decode("utf-8")) - @defer.inlineCallbacks - def put_json(self, uri, json_body, args={}, headers=None): + async def put_json(self, uri, json_body, args={}, headers=None): """ Puts some json to the given URI. Args: @@ -459,7 +479,7 @@ class SimpleHttpClient(object): headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from header name to a list of values for that header Returns: - Deferred: Succeeds when we get *any* 2xx HTTP response, with the + Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. Raises: HttpResponseException On a non-2xx HTTP response. @@ -480,19 +500,20 @@ class SimpleHttpClient(object): if headers: actual_headers.update(headers) - response = yield self.request( + response = await self.request( "PUT", uri, headers=Headers(actual_headers), data=json_str ) - body = yield make_deferred_yieldable(readBody(response)) + body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body) + return json_decoder.decode(body.decode("utf-8")) else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) - @defer.inlineCallbacks - def get_raw(self, uri, args={}, headers=None): + async def get_raw(self, uri, args={}, headers=None): """ Gets raw text from the given URI. Args: @@ -504,8 +525,8 @@ class SimpleHttpClient(object): headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from header name to a list of values for that header Returns: - Deferred: Succeeds when we get *any* 2xx HTTP response, with the - HTTP body at text. + Succeeds when we get *any* 2xx HTTP response, with the + HTTP body as bytes. Raises: HttpResponseException on a non-2xx HTTP response. """ @@ -517,20 +538,21 @@ class SimpleHttpClient(object): if headers: actual_headers.update(headers) - response = yield self.request("GET", uri, headers=Headers(actual_headers)) + response = await self.request("GET", uri, headers=Headers(actual_headers)) - body = yield make_deferred_yieldable(readBody(response)) + body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: return body else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. - @defer.inlineCallbacks - def get_file(self, url, output_stream, max_size=None, headers=None): + async def get_file(self, url, output_stream, max_size=None, headers=None): """GETs a file from a given URL Args: url (str): The URL to GET @@ -546,7 +568,7 @@ class SimpleHttpClient(object): if headers: actual_headers.update(headers) - response = yield self.request("GET", url, headers=Headers(actual_headers)) + response = await self.request("GET", url, headers=Headers(actual_headers)) resp_headers = dict(response.headers.getAllRawHeaders()) @@ -570,14 +592,14 @@ class SimpleHttpClient(object): # straight back in again try: - length = yield make_deferred_yieldable( + length = await make_deferred_yieldable( _readBodyToFile(response, output_stream, max_size) ) except SynapseError: # This can happen e.g. because the body is too large. raise except Exception as e: - raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e) + raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e return ( length, @@ -638,7 +660,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): - if isinstance(arg, text_type): + if isinstance(arg, str): return arg.encode("utf-8") elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index be7b2ceb8e..856e28454f 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -31,7 +31,7 @@ class ProxyConnectError(ConnectError): @implementer(IStreamClientEndpoint) -class HTTPConnectProxyEndpoint(object): +class HTTPConnectProxyEndpoint: """An Endpoint implementation which will send a CONNECT request to an http proxy Wraps an existing HostnameEndpoint for the proxy. diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index f5f917f5ae..83d6196d4a 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -15,6 +15,7 @@ import logging import urllib +from typing import List from netaddr import AddrFormatError, IPAddress from zope.interface import implementer @@ -35,7 +36,7 @@ logger = logging.getLogger(__name__) @implementer(IAgent) -class MatrixFederationAgent(object): +class MatrixFederationAgent: """An Agent-like thing which provides a `request` method which correctly handles resolving matrix server names when using matrix://. Handles standard https URIs as normal. @@ -48,6 +49,9 @@ class MatrixFederationAgent(object): tls_client_options_factory (FederationPolicyForHTTPS|None): factory to use for fetching client tls options, or none to disable TLS. + user_agent (bytes): + The user agent header to use for federation requests. + _srv_resolver (SrvResolver|None): SRVResolver impl to use for looking up SRV records. None to use a default implementation. @@ -61,6 +65,7 @@ class MatrixFederationAgent(object): self, reactor, tls_client_options_factory, + user_agent, _srv_resolver=None, _well_known_resolver=None, ): @@ -78,6 +83,7 @@ class MatrixFederationAgent(object): ), pool=self._pool, ) + self.user_agent = user_agent if _well_known_resolver is None: _well_known_resolver = WellKnownResolver( @@ -87,6 +93,7 @@ class MatrixFederationAgent(object): pool=self._pool, contextFactory=tls_client_options_factory, ), + user_agent=self.user_agent, ) self._well_known_resolver = _well_known_resolver @@ -127,8 +134,8 @@ class MatrixFederationAgent(object): and not _is_ip_literal(parsed_uri.hostname) and not parsed_uri.port ): - well_known_result = yield self._well_known_resolver.get_well_known( - parsed_uri.hostname + well_known_result = yield defer.ensureDeferred( + self._well_known_resolver.get_well_known(parsed_uri.hostname) ) delegated_server = well_known_result.delegated_server @@ -149,7 +156,7 @@ class MatrixFederationAgent(object): parsed_uri = urllib.parse.urlparse(uri) # We need to make sure the host header is set to the netloc of the - # server. + # server and that a user-agent is provided. if headers is None: headers = Headers() else: @@ -157,6 +164,8 @@ class MatrixFederationAgent(object): if not headers.hasHeader(b"host"): headers.addRawHeader(b"host", parsed_uri.netloc) + if not headers.hasHeader(b"user-agent"): + headers.addRawHeader(b"user-agent", self.user_agent) res = yield make_deferred_yieldable( self._agent.request(method, uri, headers, bodyProducer) @@ -166,7 +175,7 @@ class MatrixFederationAgent(object): @implementer(IAgentEndpointFactory) -class MatrixHostnameEndpointFactory(object): +class MatrixHostnameEndpointFactory: """Factory for MatrixHostnameEndpoint for parsing to an Agent. """ @@ -189,7 +198,7 @@ class MatrixHostnameEndpointFactory(object): @implementer(IStreamClientEndpoint) -class MatrixHostnameEndpoint(object): +class MatrixHostnameEndpoint: """An endpoint that resolves matrix:// URLs using Matrix server name resolution (i.e. via SRV). Does not check for well-known delegation. @@ -228,22 +237,21 @@ class MatrixHostnameEndpoint(object): return run_in_background(self._do_connect, protocol_factory) - @defer.inlineCallbacks - def _do_connect(self, protocol_factory): + async def _do_connect(self, protocol_factory): first_exception = None - server_list = yield self._resolve_server() + server_list = await self._resolve_server() for server in server_list: host = server.host port = server.port try: - logger.info("Connecting to %s:%i", host.decode("ascii"), port) + logger.debug("Connecting to %s:%i", host.decode("ascii"), port) endpoint = HostnameEndpoint(self._reactor, host, port) if self._tls_options: endpoint = wrapClientTLS(self._tls_options, endpoint) - result = yield make_deferred_yieldable( + result = await make_deferred_yieldable( endpoint.connect(protocol_factory) ) @@ -263,13 +271,9 @@ class MatrixHostnameEndpoint(object): # to try and if that doesn't work then we'll have an exception. raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,)) - @defer.inlineCallbacks - def _resolve_server(self): + async def _resolve_server(self) -> List[Server]: """Resolves the server name to a list of hosts and ports to attempt to connect to. - - Returns: - Deferred[list[Server]] """ if self._parsed_uri.scheme != b"matrix": @@ -290,7 +294,7 @@ class MatrixHostnameEndpoint(object): if port or _is_ip_literal(host): return [Server(host, port or 8448)] - server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host) + server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host) if server_list: return server_list diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 021b233a7d..d9620032d2 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -17,10 +17,10 @@ import logging import random import time +from typing import List import attr -from twisted.internet import defer from twisted.internet.error import ConnectError from twisted.names import client, dns from twisted.names.error import DNSNameError, DomainError @@ -33,7 +33,7 @@ SERVER_CACHE = {} @attr.s(slots=True, frozen=True) -class Server(object): +class Server: """ Our record of an individual server which can be tried to reach a destination. @@ -96,7 +96,7 @@ def _sort_server_list(server_list): return results -class SrvResolver(object): +class SrvResolver: """Interface to the dns client to do SRV lookups, with result caching. The default resolver in twisted.names doesn't do any caching (it has a CacheResolver, @@ -113,16 +113,14 @@ class SrvResolver(object): self._cache = cache self._get_time = get_time - @defer.inlineCallbacks - def resolve_service(self, service_name): + async def resolve_service(self, service_name: bytes) -> List[Server]: """Look up a SRV record Args: service_name (bytes): record to look up Returns: - Deferred[list[Server]]: - a list of the SRV records, or an empty list if none found + a list of the SRV records, or an empty list if none found """ now = int(self._get_time()) @@ -136,7 +134,7 @@ class SrvResolver(object): return _sort_server_list(servers) try: - answers, _, _ = yield make_deferred_yieldable( + answers, _, _ = await make_deferred_yieldable( self._dns_client.lookupService(service_name) ) except DNSNameError: diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 7ddfad286d..e6f067ca29 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -13,19 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging import random import time +from typing import Callable, Dict, Optional, Tuple import attr from twisted.internet import defer from twisted.web.client import RedirectAgent, readBody from twisted.web.http import stringToDatetime +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.logging.context import make_deferred_yieldable -from synapse.util import Clock +from synapse.util import Clock, json_decoder from synapse.util.caches.ttlcache import TTLCache from synapse.util.metrics import Measure @@ -69,16 +71,21 @@ _had_valid_well_known_cache = TTLCache("had-valid-well-known") @attr.s(slots=True, frozen=True) -class WellKnownLookupResult(object): +class WellKnownLookupResult: delegated_server = attr.ib() -class WellKnownResolver(object): +class WellKnownResolver: """Handles well-known lookups for matrix servers. """ def __init__( - self, reactor, agent, well_known_cache=None, had_well_known_cache=None + self, + reactor, + agent, + user_agent, + well_known_cache=None, + had_well_known_cache=None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -92,16 +99,16 @@ class WellKnownResolver(object): self._well_known_cache = well_known_cache self._had_valid_well_known_cache = had_well_known_cache self._well_known_agent = RedirectAgent(agent) + self.user_agent = user_agent - @defer.inlineCallbacks - def get_well_known(self, server_name): + async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult: """Attempt to fetch and parse a .well-known file for the given server Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Returns: - Deferred[WellKnownLookupResult]: The result of the lookup + The result of the lookup """ try: prev_result, expiry, ttl = self._well_known_cache.get_with_expiry( @@ -118,7 +125,9 @@ class WellKnownResolver(object): # requests for the same server in parallel? try: with Measure(self._clock, "get_well_known"): - result, cache_period = yield self._fetch_well_known(server_name) + result, cache_period = await self._fetch_well_known( + server_name + ) # type: Tuple[Optional[bytes], float] except _FetchWellKnownFailure as e: if prev_result and e.temporary: @@ -147,18 +156,17 @@ class WellKnownResolver(object): return WellKnownLookupResult(delegated_server=result) - @defer.inlineCallbacks - def _fetch_well_known(self, server_name): + async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]: """Actually fetch and parse a .well-known, without checking the cache Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Raises: _FetchWellKnownFailure if we fail to lookup a result Returns: - Deferred[Tuple[bytes,int]]: The lookup result and cache period. + The lookup result and cache period. """ had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False) @@ -166,7 +174,7 @@ class WellKnownResolver(object): # We do this in two steps to differentiate between possibly transient # errors (e.g. can't connect to host, 503 response) and more permenant # errors (such as getting a 404 response). - response, body = yield self._make_well_known_request( + response, body = await self._make_well_known_request( server_name, retry=had_valid_well_known ) @@ -174,7 +182,7 @@ class WellKnownResolver(object): if response.code != 200: raise Exception("Non-200 response %s" % (response.code,)) - parsed_body = json.loads(body.decode("utf-8")) + parsed_body = json_decoder.decode(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) result = parsed_body["m.server"].encode("ascii") @@ -209,34 +217,40 @@ class WellKnownResolver(object): return result, cache_period - @defer.inlineCallbacks - def _make_well_known_request(self, server_name, retry): + async def _make_well_known_request( + self, server_name: bytes, retry: bool + ) -> Tuple[IResponse, bytes]: """Make the well known request. This will retry the request if requested and it fails (with unable to connect or receives a 5xx error). Args: - server_name (bytes) - retry (bool): Whether to retry the request if it fails. + server_name: name of the server, from the requested url + retry: Whether to retry the request if it fails. Returns: - Deferred[tuple[IResponse, bytes]] Returns the response object and - body. Response may be a non-200 response. + Returns the response object and body. Response may be a non-200 response. """ uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") + headers = { + b"User-Agent": [self.user_agent], + } + i = 0 while True: i += 1 logger.info("Fetching %s", uri_str) try: - response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri) + response = await make_deferred_yieldable( + self._well_known_agent.request( + b"GET", uri, headers=Headers(headers) + ) ) - body = yield make_deferred_yieldable(readBody(response)) + body = await make_deferred_yieldable(readBody(response)) if 500 <= response.code < 600: raise Exception("Non-200 response %s" % (response.code,)) @@ -253,21 +267,24 @@ class WellKnownResolver(object): logger.info("Error fetching %s: %s. Retrying", uri_str, e) # Sleep briefly in the hopes that they come back up - yield self._clock.sleep(0.5) + await self._clock.sleep(0.5) -def _cache_period_from_headers(headers, time_now=time.time): +def _cache_period_from_headers( + headers: Headers, time_now: Callable[[], float] = time.time +) -> Optional[float]: cache_controls = _parse_cache_control(headers) if b"no-store" in cache_controls: return 0 if b"max-age" in cache_controls: - try: - max_age = int(cache_controls[b"max-age"]) - return max_age - except ValueError: - pass + max_age = cache_controls[b"max-age"] + if max_age: + try: + return int(max_age) + except ValueError: + pass expires = headers.getRawHeaders(b"expires") if expires is not None: @@ -283,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time): return None -def _parse_cache_control(headers): +def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: cache_controls = {} for hdr in headers.getRawHeaders(b"cache-control", []): for directive in hdr.split(b","): diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2d47b9ea00..5eaf3151ce 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,11 +17,9 @@ import cgi import logging import random import sys +import urllib from io import BytesIO -from six import raise_from, string_types -from six.moves import urllib - import attr import treq from canonicaljson import encode_canonical_json @@ -31,10 +29,11 @@ from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError -from twisted.internet.interfaces import IReactorPluggableNameResolver +from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse import synapse.metrics import synapse.util.retryutils @@ -55,6 +54,7 @@ from synapse.logging.opentracing import ( start_active_span, tags, ) +from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -76,8 +76,8 @@ MAXINT = sys.maxsize _next_id = 1 -@attr.s -class MatrixFederationRequest(object): +@attr.s(frozen=True) +class MatrixFederationRequest: method = attr.ib() """HTTP method :type: str @@ -112,27 +112,52 @@ class MatrixFederationRequest(object): :type: str|None """ + uri = attr.ib(init=False, type=bytes) + """The URI of this request + """ + def __attrs_post_init__(self): global _next_id - self.txn_id = "%s-O-%s" % (self.method, _next_id) + txn_id = "%s-O-%s" % (self.method, _next_id) _next_id = (_next_id + 1) % (MAXINT - 1) + object.__setattr__(self, "txn_id", txn_id) + + destination_bytes = self.destination.encode("ascii") + path_bytes = self.path.encode("ascii") + if self.query: + query_bytes = encode_query_args(self.query) + else: + query_bytes = b"" + + # The object is frozen so we can pre-compute this. + uri = urllib.parse.urlunparse( + (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") + ) + object.__setattr__(self, "uri", uri) + def get_json(self): if self.json_callback: return self.json_callback() return self.json -@defer.inlineCallbacks -def _handle_json_response(reactor, timeout_sec, request, response): +async def _handle_json_response( + reactor: IReactorTime, + timeout_sec: float, + request: MatrixFederationRequest, + response: IResponse, + start_ms: int, +): """ Reads the JSON body of a response, with a timeout Args: - reactor (IReactor): twisted reactor, for the timeout - timeout_sec (float): number of seconds to wait for response to complete - request (MatrixFederationRequest): the request that triggered the response - response (IResponse): response to the request + reactor: twisted reactor, for the timeout + timeout_sec: number of seconds to wait for response to complete + request: the request that triggered the response + response: response to the request + start_ms: Timestamp when request was made Returns: dict: parsed JSON response @@ -140,34 +165,48 @@ def _handle_json_response(reactor, timeout_sec, request, response): try: check_content_type_is_json(response.headers) - d = treq.json_content(response) + # Use the custom JSON decoder (partially re-implements treq.json_content). + d = treq.text_content(response, encoding="utf-8") + d.addCallback(json_decoder.decode) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) - body = yield make_deferred_yieldable(d) + body = await make_deferred_yieldable(d) except TimeoutError as e: logger.warning( - "{%s} [%s] Timed out reading response", request.txn_id, request.destination, + "{%s} [%s] Timed out reading response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), ) raise RequestSendFailed(e, can_retry=True) from e except Exception as e: logger.warning( - "{%s} [%s] Error reading response: %s", + "{%s} [%s] Error reading response %s %s: %s", request.txn_id, request.destination, + request.method, + request.uri.decode("ascii"), e, ) raise + + time_taken_secs = reactor.seconds() - start_ms / 1000 + logger.info( - "{%s} [%s] Completed: %d %s", + "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), + time_taken_secs, + request.method, + request.uri.decode("ascii"), ) return body -class MatrixFederationHttpClient(object): +class MatrixFederationHttpClient: """HTTP client used to talk to other homeservers over the federation protocol. Send client certificates and signs requests. @@ -178,7 +217,7 @@ class MatrixFederationHttpClient(object): def __init__(self, hs, tls_client_options_factory): self.hs = hs - self.signing_key = hs.config.signing_key[0] + self.signing_key = hs.signing_key self.server_name = hs.hostname real_reactor = hs.get_reactor() @@ -190,7 +229,7 @@ class MatrixFederationHttpClient(object): ) @implementer(IReactorPluggableNameResolver) - class Reactor(object): + class Reactor: def __getattr__(_self, attr): if attr == "nameResolver": return nameResolver @@ -199,7 +238,14 @@ class MatrixFederationHttpClient(object): self.reactor = Reactor() - self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory) + user_agent = hs.version_string + if hs.config.user_agent_suffix: + user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) + user_agent = user_agent.encode("ascii") + + self.agent = MatrixFederationAgent( + self.reactor, tls_client_options_factory, user_agent + ) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names @@ -219,8 +265,7 @@ class MatrixFederationHttpClient(object): self._cooperator = Cooperator(scheduler=schedule) - @defer.inlineCallbacks - def _send_request_with_optional_trailing_slash( + async def _send_request_with_optional_trailing_slash( self, request, try_trailing_slash_on_400=False, **send_request_args ): """Wrapper for _send_request which can optionally retry the request @@ -241,10 +286,10 @@ class MatrixFederationHttpClient(object): (except 429). Returns: - Deferred[Dict]: Parsed JSON response body. + Dict: Parsed JSON response body. """ try: - response = yield self._send_request(request, **send_request_args) + response = await self._send_request(request, **send_request_args) except HttpResponseException as e: # Received an HTTP error > 300. Check if it meets the requirements # to retry with a trailing slash @@ -258,14 +303,15 @@ class MatrixFederationHttpClient(object): # 'M_UNRECOGNIZED' which some endpoints can return when omitting a # trailing slash on Synapse <= v0.99.3. logger.info("Retrying request with trailing slash") - request.path += "/" - response = yield self._send_request(request, **send_request_args) + # Request is frozen so we create a new instance + request = attr.evolve(request, path=request.path + "/") + + response = await self._send_request(request, **send_request_args) return response - @defer.inlineCallbacks - def _send_request( + async def _send_request( self, request, retry_on_dns_fail=True, @@ -306,7 +352,7 @@ class MatrixFederationHttpClient(object): backoff_on_404 (bool): Back off if we get a 404 Returns: - Deferred[twisted.web.client.Response]: resolves with the HTTP + twisted.web.client.Response: resolves with the HTTP response object on success. Raises: @@ -330,7 +376,7 @@ class MatrixFederationHttpClient(object): ): raise FederationDeniedError(request.destination) - limiter = yield synapse.util.retryutils.get_retry_limiter( + limiter = await synapse.util.retryutils.get_retry_limiter( request.destination, self.clock, self._store, @@ -371,9 +417,7 @@ class MatrixFederationHttpClient(object): else: retries_left = MAX_SHORT_RETRIES - url_bytes = urllib.parse.urlunparse( - (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") - ) + url_bytes = request.uri url_str = url_bytes.decode("ascii") url_to_sign_bytes = urllib.parse.urlunparse( @@ -400,7 +444,7 @@ class MatrixFederationHttpClient(object): headers_dict[b"Authorization"] = auth_headers - logger.info( + logger.debug( "{%s} [%s] Sending request: %s %s; timeout %fs", request.txn_id, request.destination, @@ -428,20 +472,20 @@ class MatrixFederationHttpClient(object): reactor=self.reactor, ) - response = yield request_deferred + response = await request_deferred except TimeoutError as e: raise RequestSendFailed(e, can_retry=True) from e except DNSLookupError as e: - raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e) + raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e except Exception as e: - logger.info("Failed to send request: %s", e) - raise_from(RequestSendFailed(e, can_retry=True), e) + raise RequestSendFailed(e, can_retry=True) from e incoming_responses_counter.labels( request.method, response.code ).inc() set_tag(tags.HTTP_STATUS_CODE, response.code) + response_phrase = response.phrase.decode("ascii", errors="replace") if 200 <= response.code < 300: logger.debug( @@ -449,7 +493,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode("ascii", errors="replace"), + response_phrase, ) pass else: @@ -458,7 +502,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode("ascii", errors="replace"), + response_phrase, ) # :'( # Update transactions table? @@ -468,7 +512,7 @@ class MatrixFederationHttpClient(object): ) try: - body = yield make_deferred_yieldable(d) + body = await make_deferred_yieldable(d) except Exception as e: # Eh, we're already going to raise an exception so lets # ignore if this fails. @@ -482,18 +526,18 @@ class MatrixFederationHttpClient(object): ) body = None - e = HttpResponseException(response.code, response.phrase, body) + e = HttpResponseException(response.code, response_phrase, body) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException if response.code == 429: - raise_from(RequestSendFailed(e, can_retry=True), e) + raise RequestSendFailed(e, can_retry=True) from e else: raise e break except RequestSendFailed as e: - logger.warning( + logger.info( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -522,7 +566,7 @@ class MatrixFederationHttpClient(object): delay, ) - yield self.clock.sleep(delay) + await self.clock.sleep(delay) retries_left -= 1 else: raise @@ -557,13 +601,17 @@ class MatrixFederationHttpClient(object): Returns: list[bytes]: a list of headers to be added as "Authorization:" headers """ - request = {"method": method, "uri": url_bytes, "origin": self.server_name} + request = { + "method": method.decode("ascii"), + "uri": url_bytes.decode("ascii"), + "origin": self.server_name, + } if destination is not None: - request["destination"] = destination + request["destination"] = destination.decode("ascii") if destination_is is not None: - request["destination_is"] = destination_is + request["destination_is"] = destination_is.decode("ascii") if content is not None: request["content"] = content @@ -581,8 +629,7 @@ class MatrixFederationHttpClient(object): ) return auth_headers - @defer.inlineCallbacks - def put_json( + async def put_json( self, destination, path, @@ -626,7 +673,7 @@ class MatrixFederationHttpClient(object): enabled. Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -648,7 +695,9 @@ class MatrixFederationHttpClient(object): json=data, ) - response = yield self._send_request_with_optional_trailing_slash( + start_ms = self.clock.time_msec() + + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, backoff_on_404=backoff_on_404, @@ -657,14 +706,13 @@ class MatrixFederationHttpClient(object): timeout=timeout, ) - body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response + body = await _handle_json_response( + self.reactor, self.default_timeout, request, response, start_ms ) return body - @defer.inlineCallbacks - def post_json( + async def post_json( self, destination, path, @@ -697,7 +745,7 @@ class MatrixFederationHttpClient(object): args (dict): query params Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -715,7 +763,9 @@ class MatrixFederationHttpClient(object): method="POST", destination=destination, path=path, query=args, json=data ) - response = yield self._send_request( + start_ms = self.clock.time_msec() + + response = await self._send_request( request, long_retries=long_retries, timeout=timeout, @@ -727,13 +777,12 @@ class MatrixFederationHttpClient(object): else: _sec_timeout = self.default_timeout - body = yield _handle_json_response( - self.reactor, _sec_timeout, request, response + body = await _handle_json_response( + self.reactor, _sec_timeout, request, response, start_ms, ) return body - @defer.inlineCallbacks - def get_json( + async def get_json( self, destination, path, @@ -765,7 +814,7 @@ class MatrixFederationHttpClient(object): response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -782,7 +831,9 @@ class MatrixFederationHttpClient(object): method="GET", destination=destination, path=path, query=args ) - response = yield self._send_request_with_optional_trailing_slash( + start_ms = self.clock.time_msec() + + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, backoff_on_404=False, @@ -791,14 +842,13 @@ class MatrixFederationHttpClient(object): timeout=timeout, ) - body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response + body = await _handle_json_response( + self.reactor, self.default_timeout, request, response, start_ms ) return body - @defer.inlineCallbacks - def delete_json( + async def delete_json( self, destination, path, @@ -826,7 +876,7 @@ class MatrixFederationHttpClient(object): args (dict): query params Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -843,20 +893,21 @@ class MatrixFederationHttpClient(object): method="DELETE", destination=destination, path=path, query=args ) - response = yield self._send_request( + start_ms = self.clock.time_msec() + + response = await self._send_request( request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response + body = await _handle_json_response( + self.reactor, self.default_timeout, request, response, start_ms ) return body - @defer.inlineCallbacks - def get_file( + async def get_file( self, destination, path, @@ -876,7 +927,7 @@ class MatrixFederationHttpClient(object): and try the request anyway. Returns: - Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of + tuple[int, dict]: Resolves with an (int,dict) tuple of the file length and a dict of the response headers. Raises: @@ -893,7 +944,7 @@ class MatrixFederationHttpClient(object): method="GET", destination=destination, path=path, query=args ) - response = yield self._send_request( + response = await self._send_request( request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff ) @@ -902,7 +953,7 @@ class MatrixFederationHttpClient(object): try: d = _readBodyToFile(response, output_stream, max_size) d.addTimeout(self.default_timeout, self.reactor) - length = yield make_deferred_yieldable(d) + length = await make_deferred_yieldable(d) except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", @@ -912,12 +963,14 @@ class MatrixFederationHttpClient(object): ) raise logger.info( - "{%s} [%s] Completed: %d %s [%d bytes]", + "{%s} [%s] Completed: %d %s [%d bytes] %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), length, + request.method, + request.uri.decode("ascii"), ) return (length, headers) @@ -998,7 +1051,7 @@ def encode_query_args(args): encoded_args = {} for k, vs in args.items(): - if isinstance(vs, string_types): + if isinstance(vs, str): vs = [vs] encoded_args[k] = [v.encode("UTF-8") for v in vs] diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index b58ae3d9db..cd94e789e8 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -145,7 +145,7 @@ LaterGauge( ) -class RequestMetrics(object): +class RequestMetrics: def start(self, time_sec, name, method): self.start = time_sec self.start_context = current_context() diff --git a/synapse/http/server.py b/synapse/http/server.py index 2487a72171..996a31a9ec 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,23 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import collections import html -import http.client import logging import types import urllib +from http import HTTPStatus from io import BytesIO -from typing import Awaitable, Callable, TypeVar, Union +from typing import Any, Callable, Dict, Iterator, List, Tuple, Union import jinja2 -from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json +from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json +from zope.interface import implementer -from twisted.internet import defer +from twisted.internet import defer, interfaces from twisted.python import failure from twisted.web import resource from twisted.web.server import NOT_DONE_YET, Request -from twisted.web.static import NoRangeStaticProducer +from twisted.web.static import File, NoRangeStaticProducer from twisted.web.util import redirectTo import synapse.events @@ -45,6 +47,7 @@ from synapse.api.errors import ( from synapse.http.site import SynapseRequest from synapse.logging.context import preserve_fn from synapse.logging.opentracing import trace_servlet +from synapse.util import json_encoder from synapse.util.caches import intern_dict logger = logging.getLogger(__name__) @@ -62,99 +65,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> """ -def wrap_json_request_handler(h): - """Wraps a request handler method with exception handling. - - Also does the wrapping with request.processing as per wrap_async_request_handler. - - The handler method must have a signature of "handle_foo(self, request)", - where "request" must be a SynapseRequest. - - The handler must return a deferred or a coroutine. If the deferred succeeds - we assume that a response has been sent. If the deferred fails with a SynapseError we use - it to send a JSON response with the appropriate HTTP reponse code. If the - deferred fails with any other type of error we send a 500 reponse. +def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: + """Sends a JSON error response to clients. """ - async def wrapped_request_handler(self, request): - try: - await h(self, request) - except SynapseError as e: - code = e.code - logger.info("%s SynapseError: %s - %s", request, code, e.msg) - - # Only respond with an error response if we haven't already started - # writing, otherwise lets just kill the connection - if request.startedWriting: - if request.transport: - try: - request.transport.abortConnection() - except Exception: - # abortConnection throws if the connection is already closed - pass - else: - respond_with_json( - request, - code, - e.error_dict(), - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) - - except Exception: - # failure.Failure() fishes the original Failure out - # of our stack, and thus gives us a sensible stack - # trace. - f = failure.Failure() - logger.error( - "Failed handle request via %r: %r", - request.request_metrics.name, - request, - exc_info=(f.type, f.value, f.getTracebackObject()), - ) - # Only respond with an error response if we haven't already started - # writing, otherwise lets just kill the connection - if request.startedWriting: - if request.transport: - try: - request.transport.abortConnection() - except Exception: - # abortConnection throws if the connection is already closed - pass - else: - respond_with_json( - request, - 500, - {"error": "Internal server error", "errcode": Codes.UNKNOWN}, - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), - ) - - return wrap_async_request_handler(wrapped_request_handler) - + if f.check(SynapseError): + error_code = f.value.code + error_dict = f.value.error_dict() -TV = TypeVar("TV") - - -def wrap_html_request_handler( - h: Callable[[TV, SynapseRequest], Awaitable] -) -> Callable[[TV, SynapseRequest], Awaitable[None]]: - """Wraps a request handler method with exception handling. - - Also does the wrapping with request.processing as per wrap_async_request_handler. - - The handler method must have a signature of "handle_foo(self, request)", - where "request" must be a SynapseRequest. - """ + logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg) + else: + error_code = 500 + error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN} - async def wrapped_request_handler(self, request): - try: - await h(self, request) - except Exception: - f = failure.Failure() - return_html_error(f, request, HTML_ERROR_TEMPLATE) + logger.error( + "Failed handle request via %r: %r", + request.request_metrics.name, + request, + exc_info=(f.type, f.value, f.getTracebackObject()), + ) - return wrap_async_request_handler(wrapped_request_handler) + # Only respond with an error response if we haven't already started writing, + # otherwise lets just kill the connection + if request.startedWriting: + if request.transport: + try: + request.transport.abortConnection() + except Exception: + # abortConnection throws if the connection is already closed + pass + else: + respond_with_json( + request, + error_code, + error_dict, + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + ) def return_html_error( @@ -188,7 +135,7 @@ def return_html_error( exc_info=(f.type, f.value, f.getTracebackObject()), ) else: - code = http.HTTPStatus.INTERNAL_SERVER_ERROR + code = HTTPStatus.INTERNAL_SERVER_ERROR msg = "Internal server error" logger.error( @@ -202,12 +149,7 @@ def return_html_error( else: body = error_template.render(code=code, msg=msg) - body_bytes = body.encode("utf-8") - request.setResponseCode(code) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%i" % (len(body_bytes),)) - request.write(body_bytes) - finish_request(request) + respond_with_html(request, code, body) def wrap_async_request_handler(h): @@ -232,7 +174,7 @@ def wrap_async_request_handler(h): return preserve_fn(wrapped_async_request_handler) -class HttpServer(object): +class HttpServer: """ Interface for registering callbacks on a HTTP server """ @@ -254,7 +196,115 @@ class HttpServer(object): pass -class JsonResource(HttpServer, resource.Resource): +class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): + """Base class for resources that have async handlers. + + Sub classes can either implement `_async_render_<METHOD>` to handle + requests by method, or override `_async_render` to handle all requests. + + Args: + extract_context: Whether to attempt to extract the opentracing + context from the request the servlet is handling. + """ + + def __init__(self, extract_context=False): + super().__init__() + + self._extract_context = extract_context + + def render(self, request): + """ This gets called by twisted every time someone sends us a request. + """ + defer.ensureDeferred(self._async_render_wrapper(request)) + return NOT_DONE_YET + + @wrap_async_request_handler + async def _async_render_wrapper(self, request: SynapseRequest): + """This is a wrapper that delegates to `_async_render` and handles + exceptions, return values, metrics, etc. + """ + try: + request.request_metrics.name = self.__class__.__name__ + + with trace_servlet(request, self._extract_context): + callback_return = await self._async_render(request) + + if callback_return is not None: + code, response = callback_return + self._send_response(request, code, response) + except Exception: + # failure.Failure() fishes the original Failure out + # of our stack, and thus gives us a sensible stack + # trace. + f = failure.Failure() + self._send_error_response(f, request) + + async def _async_render(self, request: Request): + """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if + no appropriate method exists. Can be overriden in sub classes for + different routing. + """ + # Treat HEAD requests as GET requests. + request_method = request.method.decode("ascii") + if request_method == "HEAD": + request_method = "GET" + + method_handler = getattr(self, "_async_render_%s" % (request_method,), None) + if method_handler: + raw_callback_return = method_handler(request) + + # Is it synchronous? We'll allow this for now. + if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): + callback_return = await raw_callback_return + else: + callback_return = raw_callback_return + + return callback_return + + _unrecognised_request_handler(request) + + @abc.abstractmethod + def _send_response( + self, request: SynapseRequest, code: int, response_object: Any, + ) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def _send_error_response( + self, f: failure.Failure, request: SynapseRequest, + ) -> None: + raise NotImplementedError() + + +class DirectServeJsonResource(_AsyncResource): + """A resource that will call `self._async_on_<METHOD>` on new requests, + formatting responses and errors as JSON. + """ + + def _send_response( + self, request: Request, code: int, response_object: Any, + ): + """Implements _AsyncResource._send_response + """ + # TODO: Only enable CORS for the requests that need it. + respond_with_json( + request, + code, + response_object, + send_cors=True, + pretty_print=_request_user_agent_is_curl(request), + canonical_json=self.canonical_json, + ) + + def _send_error_response( + self, f: failure.Failure, request: SynapseRequest, + ) -> None: + """Implements _AsyncResource._send_error_response + """ + return_json_error(f, request) + + +class JsonResource(DirectServeJsonResource): """ This implements the HttpServer interface and provides JSON support for Resources. @@ -274,17 +324,15 @@ class JsonResource(HttpServer, resource.Resource): "_PathEntry", ["pattern", "callback", "servlet_classname"] ) - def __init__(self, hs, canonical_json=True): - resource.Resource.__init__(self) + def __init__(self, hs, canonical_json=True, extract_context=False): + super().__init__(extract_context) self.canonical_json = canonical_json self.clock = hs.get_clock() self.path_regexs = {} self.hs = hs - def register_paths( - self, method, path_patterns, callback, servlet_classname, trace=True - ): + def register_paths(self, method, path_patterns, callback, servlet_classname): """ Registers a request handler against a regular expression. Later request URLs are checked against these regular expressions in order to identify an appropriate @@ -300,37 +348,46 @@ class JsonResource(HttpServer, resource.Resource): servlet_classname (str): The name of the handler to be used in prometheus and opentracing logs. - - trace (bool): Whether we should start a span to trace the servlet. """ method = method.encode("utf-8") # method is bytes on py3 - if trace: - # We don't extract the context from the servlet because we can't - # trust the sender - callback = trace_servlet(servlet_classname)(callback) - for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( self._PathEntry(path_pattern, callback, servlet_classname) ) - def render(self, request): - """ This gets called by twisted every time someone sends us a request. + def _get_handler_for_request( + self, request: SynapseRequest + ) -> Tuple[Callable, str, Dict[str, str]]: + """Finds a callback method to handle the given request. + + Returns: + A tuple of the callback to use, the name of the servlet, and the + key word arguments to pass to the callback """ - defer.ensureDeferred(self._async_render(request)) - return NOT_DONE_YET + # Treat HEAD requests as GET requests. + request_path = request.path.decode("ascii") + request_method = request.method + if request_method == b"HEAD": + request_method = b"GET" + + # Loop through all the registered callbacks to check if the method + # and path regex match + for path_entry in self.path_regexs.get(request_method, []): + m = path_entry.pattern.match(request_path) + if m: + # We found a match! + return path_entry.callback, path_entry.servlet_classname, m.groupdict() + + # Huh. No one wanted to handle that? Fiiiiiine. Send 400. + return _unrecognised_request_handler, "unrecognised_request_handler", {} - @wrap_json_request_handler async def _async_render(self, request): - """ This gets called from render() every time someone sends us a request. - This checks if anyone has registered a callback for that method and - path. - """ callback, servlet_classname, group_dict = self._get_handler_for_request(request) - # Make sure we have a name for this handler in prometheus. + # Make sure we have an appopriate name for this handler in prometheus + # (rather than the default of JsonResource). request.request_metrics.name = servlet_classname # Now trigger the callback. If it returns a response, we send it @@ -343,96 +400,54 @@ class JsonResource(HttpServer, resource.Resource): } ) - callback_return = callback(request, **kwargs) + raw_callback_return = callback(request, **kwargs) # Is it synchronous? We'll allow this for now. - if isinstance(callback_return, (defer.Deferred, types.CoroutineType)): - callback_return = await callback_return - - if callback_return is not None: - code, response = callback_return - self._send_response(request, code, response) - - def _get_handler_for_request(self, request): - """Finds a callback method to handle the given request + if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): + callback_return = await raw_callback_return + else: + callback_return = raw_callback_return - Args: - request (twisted.web.http.Request): + return callback_return - Returns: - Tuple[Callable, str, dict[unicode, unicode]]: callback method, the - label to use for that method in prometheus metrics, and the - dict mapping keys to path components as specified in the - handler's path match regexp. - - The callback will normally be a method registered via - register_paths, so will return (possibly via Deferred) either - None, or a tuple of (http code, response body). - """ - request_path = request.path.decode("ascii") - # Loop through all the registered callbacks to check if the method - # and path regex match - for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request_path) - if m: - # We found a match! - return path_entry.callback, path_entry.servlet_classname, m.groupdict() +class DirectServeHtmlResource(_AsyncResource): + """A resource that will call `self._async_on_<METHOD>` on new requests, + formatting responses and errors as HTML. + """ - # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - return _unrecognised_request_handler, "unrecognised_request_handler", {} + # The error template to use for this resource + ERROR_TEMPLATE = HTML_ERROR_TEMPLATE def _send_response( - self, request, code, response_json_object, response_code_message=None + self, request: SynapseRequest, code: int, response_object: Any, ): - # TODO: Only enable CORS for the requests that need it. - respond_with_json( - request, - code, - response_json_object, - send_cors=True, - response_code_message=response_code_message, - pretty_print=_request_user_agent_is_curl(request), - canonical_json=self.canonical_json, - ) - - -class DirectServeResource(resource.Resource): - def render(self, request): + """Implements _AsyncResource._send_response """ - Render the request, using an asynchronous render handler if it exists. - """ - async_render_callback_name = "_async_render_" + request.method.decode("ascii") - - # Try and get the async renderer - callback = getattr(self, async_render_callback_name, None) - - # No async renderer for this request method. - if not callback: - return super().render(request) - - resp = trace_servlet(self.__class__.__name__)(callback)(request) - - # If it's a coroutine, turn it into a Deferred - if isinstance(resp, types.CoroutineType): - defer.ensureDeferred(resp) - - return NOT_DONE_YET + # We expect to get bytes for us to write + assert isinstance(response_object, bytes) + html_bytes = response_object + respond_with_html_bytes(request, 200, html_bytes) -def _options_handler(request): - """Request handler for OPTIONS requests + def _send_error_response( + self, f: failure.Failure, request: SynapseRequest, + ) -> None: + """Implements _AsyncResource._send_error_response + """ + return_html_error(f, request, self.ERROR_TEMPLATE) - This is a request handler suitable for return from - _get_handler_for_request. It returns a 200 and an empty body. - Args: - request (twisted.web.http.Request): +class StaticResource(File): + """ + A resource that represents a plain non-interpreted file or directory. - Returns: - Tuple[int, dict]: http code, response body. + Differs from the File resource by adding clickjacking protection. """ - return 200, {} + + def render_GET(self, request: Request): + set_clickjacking_protection_headers(request) + return super().render_GET(request) def _unrecognised_request_handler(request): @@ -468,11 +483,12 @@ class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" def render_OPTIONS(self, request): - code, response_json_object = _options_handler(request) + request.setResponseCode(204) + request.setHeader(b"Content-Length", b"0") - return respond_with_json( - request, code, response_json_object, send_cors=True, canonical_json=False, - ) + set_cors_headers(request) + + return b"" def getChildWithDefault(self, path, request): if request.method == b"OPTIONS": @@ -484,15 +500,114 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect): pass +@implementer(interfaces.IPushProducer) +class _ByteProducer: + """ + Iteratively write bytes to the request. + """ + + # The minimum number of bytes for each chunk. Note that the last chunk will + # usually be smaller than this. + min_chunk_size = 1024 + + def __init__( + self, request: Request, iterator: Iterator[bytes], + ): + self._request = request + self._iterator = iterator + self._paused = False + + # Register the producer and start producing data. + self._request.registerProducer(self, True) + self.resumeProducing() + + def _send_data(self, data: List[bytes]) -> None: + """ + Send a list of bytes as a chunk of a response. + """ + if not data: + return + self._request.write(b"".join(data)) + + def pauseProducing(self) -> None: + self._paused = True + + def resumeProducing(self) -> None: + # We've stopped producing in the meantime (note that this might be + # re-entrant after calling write). + if not self._request: + return + + self._paused = False + + # Write until there's backpressure telling us to stop. + while not self._paused: + # Get the next chunk and write it to the request. + # + # The output of the JSON encoder is buffered and coalesced until + # min_chunk_size is reached. This is because JSON encoders produce + # very small output per iteration and the Request object converts + # each call to write() to a separate chunk. Without this there would + # be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n"). + # + # Note that buffer stores a list of bytes (instead of appending to + # bytes) to hopefully avoid many allocations. + buffer = [] + buffered_bytes = 0 + while buffered_bytes < self.min_chunk_size: + try: + data = next(self._iterator) + buffer.append(data) + buffered_bytes += len(data) + except StopIteration: + # The entire JSON object has been serialized, write any + # remaining data, finalize the producer and the request, and + # clean-up any references. + self._send_data(buffer) + self._request.unregisterProducer() + self._request.finish() + self.stopProducing() + return + + self._send_data(buffer) + + def stopProducing(self) -> None: + # Clear a circular reference. + self._request = None + + +def _encode_json_bytes(json_object: Any) -> Iterator[bytes]: + """ + Encode an object into JSON. Returns an iterator of bytes. + """ + for chunk in json_encoder.iterencode(json_object): + yield chunk.encode("utf-8") + + def respond_with_json( - request, - code, - json_object, - send_cors=False, - response_code_message=None, - pretty_print=False, - canonical_json=True, + request: Request, + code: int, + json_object: Any, + send_cors: bool = False, + pretty_print: bool = False, + canonical_json: bool = True, ): + """Sends encoded JSON in response to the given request. + + Args: + request: The http request to respond to. + code: The HTTP response code. + json_object: The object to serialize to JSON. + send_cors: Whether to send Cross-Origin Resource Sharing headers + https://fetch.spec.whatwg.org/#http-cors-protocol + pretty_print: Whether to include indentation and line-breaks in the + resulting JSON bytes. + canonical_json: Whether to use the canonicaljson algorithm when encoding + the JSON bytes. + + Returns: + twisted.web.server.NOT_DONE_YET if the request is still active. + """ # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. @@ -500,41 +615,44 @@ def respond_with_json( logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None if pretty_print: - json_bytes = encode_pretty_printed_json(json_object) + b"\n" + encoder = iterencode_pretty_printed_json else: if canonical_json or synapse.events.USE_FROZEN_DICTS: - # canonicaljson already encodes to bytes - json_bytes = encode_canonical_json(json_object) + encoder = iterencode_canonical_json else: - json_bytes = json.dumps(json_object).encode("utf-8") - - return respond_with_json_bytes( - request, - code, - json_bytes, - send_cors=send_cors, - response_code_message=response_code_message, - ) + encoder = _encode_json_bytes + + request.setResponseCode(code) + request.setHeader(b"Content-Type", b"application/json") + request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") + + if send_cors: + set_cors_headers(request) + + _ByteProducer(request, encoder(json_object)) + return NOT_DONE_YET def respond_with_json_bytes( - request, code, json_bytes, send_cors=False, response_code_message=None + request: Request, code: int, json_bytes: bytes, send_cors: bool = False, ): """Sends encoded JSON in response to the given request. Args: - request (twisted.web.http.Request): The http request to respond to. - code (int): The HTTP response code. - json_bytes (bytes): The json bytes to use as the response body. - send_cors (bool): Whether to send Cross-Origin Resource Sharing headers - http://www.w3.org/TR/cors/ + request: The http request to respond to. + code: The HTTP response code. + json_bytes: The json bytes to use as the response body. + send_cors: Whether to send Cross-Origin Resource Sharing headers + https://fetch.spec.whatwg.org/#http-cors-protocol + Returns: - twisted.web.server.NOT_DONE_YET""" + twisted.web.server.NOT_DONE_YET if the request is still active. + """ - request.setResponseCode(code, message=response_code_message) + request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") @@ -542,8 +660,8 @@ def respond_with_json_bytes( if send_cors: set_cors_headers(request) - # todo: we can almost certainly avoid this copy and encode the json straight into - # the bytesIO, but it would involve faffing around with string->bytes wrappers. + # note that this is zero-copy (the bytesio shares a copy-on-write buffer with + # the original `bytes`). bytes_io = BytesIO(json_bytes) producer = NoRangeStaticProducer(request, bytes_io) @@ -551,16 +669,16 @@ def respond_with_json_bytes( return NOT_DONE_YET -def set_cors_headers(request): - """Set the CORs headers so that javascript running in a web browsers can +def set_cors_headers(request: Request): + """Set the CORS headers so that javascript running in a web browsers can use this API Args: - request (twisted.web.http.Request): The http request to add CORs to. + request: The http request to add CORS to. """ request.setHeader(b"Access-Control-Allow-Origin", b"*") request.setHeader( - b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS" + b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) request.setHeader( b"Access-Control-Allow-Headers", @@ -568,7 +686,60 @@ def set_cors_headers(request): ) -def finish_request(request): +def respond_with_html(request: Request, code: int, html: str): + """ + Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes. + """ + respond_with_html_bytes(request, code, html.encode("utf-8")) + + +def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): + """ + Sends HTML (encoded as UTF-8 bytes) as the response to the given request. + + Note that this adds clickjacking protection headers and finishes the request. + + Args: + request: The http request to respond to. + code: The HTTP response code. + html_bytes: The HTML bytes to use as the response body. + """ + # could alternatively use request.notifyFinish() and flip a flag when + # the Deferred fires, but since the flag is RIGHT THERE it seems like + # a waste. + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return + + request.setResponseCode(code) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + # Ensure this content cannot be embedded. + set_clickjacking_protection_headers(request) + + request.write(html_bytes) + finish_request(request) + + +def set_clickjacking_protection_headers(request: Request): + """ + Set headers to guard against clickjacking of embedded content. + + This sets the X-Frame-Options and Content-Security-Policy headers which instructs + browsers to not allow the HTML of the response to be embedded onto another + page. + + Args: + request: The http request to add the headers to. + """ + request.setHeader(b"X-Frame-Options", b"DENY") + request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") + + +def finish_request(request: Request): """ Finish writing the response to the request. Twisted throws a RuntimeException if the connection closed before the @@ -587,7 +758,7 @@ def finish_request(request): logger.info("Connection disconnected before response was written: %r", e) -def _request_user_agent_is_curl(request): +def _request_user_agent_is_curl(request: Request) -> bool: user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) for user_agent in user_agents: if b"curl" in user_agent: diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 13fcb408a6..fd90ba7828 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -17,9 +17,8 @@ import logging -from canonicaljson import json - from synapse.api.errors import Codes, SynapseError +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -214,16 +213,8 @@ def parse_json_value_from_request(request, allow_empty_body=False): if not content_bytes and allow_empty_body: return None - # Decode to Unicode so that simplejson will return Unicode strings on - # Python 2 - try: - content_unicode = content_bytes.decode("utf8") - except UnicodeDecodeError: - logger.warning("Unable to decode UTF-8") - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) - try: - content = json.loads(content_unicode) + content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) @@ -265,7 +256,7 @@ def assert_params_in_dict(body, required): raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) -class RestServlet(object): +class RestServlet: """ A Synapse REST Servlet. diff --git a/synapse/http/site.py b/synapse/http/site.py index 167293c46d..6e79b47828 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -19,6 +19,7 @@ from typing import Optional from twisted.python.failure import Failure from twisted.web.server import Request, Site +from synapse.config.server import ListenerConfig from synapse.http import redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import LoggingContext, PreserveLoggingContext @@ -145,10 +146,9 @@ class SynapseRequest(Request): Returns a context manager; the correct way to use this is: - @defer.inlineCallbacks - def handle_request(request): + async def handle_request(request): with request.processing("FooServlet"): - yield really_handle_the_request() + await really_handle_the_request() Once the context manager is closed, the completion of the request will be logged, and the various metrics will be updated. @@ -214,9 +214,7 @@ class SynapseRequest(Request): # It's useful to log it here so that we can get an idea of when # the client disconnects. with PreserveLoggingContext(self.logcontext): - logger.warning( - "Error processing request %r: %s %s", self, reason.type, reason.value - ) + logger.info("Connection from client lost before response was sent") if not self._is_processing: self._finished_processing() @@ -288,7 +286,9 @@ class SynapseRequest(Request): # the connection dropped) code += "!" - self.site.access_logger.info( + log_level = logging.INFO if self._should_log_request() else logging.DEBUG + self.site.access_logger.log( + log_level, "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" ' %sB %s "%s %s %s" "%s" [%d dbevts]', @@ -316,6 +316,17 @@ class SynapseRequest(Request): except Exception as e: logger.warning("Failed to stop metrics: %r", e) + def _should_log_request(self) -> bool: + """Whether we should log at INFO that we processed the request. + """ + if self.path == b"/health": + return False + + if self.method == b"OPTIONS": + return False + + return True + class XForwardedForRequest(SynapseRequest): def __init__(self, *args, **kw): @@ -350,7 +361,7 @@ class SynapseSite(Site): self, logger_name, site_tag, - config, + config: ListenerConfig, resource, server_version_string, *args, @@ -360,7 +371,8 @@ class SynapseSite(Site): self.site_tag = site_tag - proxied = config.get("x_forwarded", False) + assert config.http_options is not None + proxied = config.http_options.x_forwarded self.requestFactory = XForwardedForRequest if proxied else SynapseRequest self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 7372450b45..144506c8f2 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -55,7 +55,7 @@ def stdlib_log_level_to_twisted(level: str) -> LogLevel: @attr.s @implementer(ILogObserver) -class LogContextObserver(object): +class LogContextObserver: """ An ILogObserver which adds Synapse-specific log context information. @@ -169,7 +169,7 @@ class OutputPipeType(Values): @attr.s -class DrainConfiguration(object): +class DrainConfiguration: name = attr.ib() type = attr.ib() location = attr.ib() @@ -177,7 +177,7 @@ class DrainConfiguration(object): @attr.s -class NetworkJSONTerseOptions(object): +class NetworkJSONTerseOptions: maximum_buffer = attr.ib(type=int) diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index c0b9384189..1b8916cfa2 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -152,7 +152,7 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb @attr.s @implementer(IPushProducer) -class LogProducer(object): +class LogProducer: """ An IPushProducer that writes logs from its buffer to its transport when it is resumed. @@ -190,7 +190,7 @@ class LogProducer(object): @attr.s @implementer(ILogObserver) -class TerseJSONToTCPLogObserver(object): +class TerseJSONToTCPLogObserver: """ An IObserver that writes JSON logs to a TCP target. diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 8b9c4e38bd..22598e02d2 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -74,7 +74,7 @@ except Exception: get_thread_id = threading.get_ident -class ContextResourceUsage(object): +class ContextResourceUsage: """Object for tracking the resources used by a log context Attributes: @@ -179,7 +179,7 @@ class ContextResourceUsage(object): LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] -class _Sentinel(object): +class _Sentinel: """Sentinel to represent the root context""" __slots__ = ["previous_context", "finished", "request", "scope", "tag"] @@ -226,7 +226,7 @@ class _Sentinel(object): SENTINEL_CONTEXT = _Sentinel() -class LoggingContext(object): +class LoggingContext: """Additional context for log formatting. Contexts are scoped within a "with" block. @@ -566,36 +566,33 @@ class LoggingContextFilter(logging.Filter): return True -class PreserveLoggingContext(object): - """Captures the current logging context and restores it when the scope is - exited. Used to restore the context after a function using - @defer.inlineCallbacks is resumed by a callback from the reactor.""" +class PreserveLoggingContext: + """Context manager which replaces the logging context - __slots__ = ["current_context", "new_context", "has_parent"] + The previous logging context is restored on exit.""" + + __slots__ = ["_old_context", "_new_context"] def __init__( self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT ) -> None: - self.new_context = new_context + self._new_context = new_context def __enter__(self) -> None: - """Captures the current logging context""" - self.current_context = set_current_context(self.new_context) - - if self.current_context: - self.has_parent = self.current_context.previous_context is not None + self._old_context = set_current_context(self._new_context) def __exit__(self, type, value, traceback) -> None: - """Restores the current logging context""" - context = set_current_context(self.current_context) + context = set_current_context(self._old_context) - if context != self.new_context: + if context != self._new_context: if not context: - logger.warning("Expected logging context %s was lost", self.new_context) + logger.warning( + "Expected logging context %s was lost", self._new_context + ) else: logger.warning( "Expected logging context %s but found %s", - self.new_context, + self._new_context, context, ) diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py index fbf570c756..d736ad5b9b 100644 --- a/synapse/logging/formatter.py +++ b/synapse/logging/formatter.py @@ -16,8 +16,7 @@ import logging import traceback - -from six import StringIO +from io import StringIO class LogFormatter(logging.Formatter): diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 5dddf57008..7df0aa197d 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -164,28 +164,28 @@ Gotchas than one caller? Will all of those calling functions have be in a context with an active span? """ - import contextlib import inspect import logging import re -import types from functools import wraps -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Optional, Type -from canonicaljson import json +import attr from twisted.internet import defer from synapse.config import ConfigError +from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: + from synapse.http.site import SynapseRequest from synapse.server import HomeServer # Helper class -class _DummyTagNames(object): +class _DummyTagNames: """wrapper of opentracings tags. We need to have them if we want to reference them without opentracing around. Clearly they should never actually show up in a trace. `set_tags` overwrites @@ -226,12 +226,37 @@ except ImportError: tags = _DummyTagNames try: from jaeger_client import Config as JaegerConfig + from synapse.logging.scopecontextmanager import LogContextScopeManager except ImportError: JaegerConfig = None # type: ignore LogContextScopeManager = None # type: ignore +try: + from rust_python_jaeger_reporter import Reporter + + @attr.s(slots=True, frozen=True) + class _WrappedRustReporter: + """Wrap the reporter to ensure `report_span` never throws. + """ + + _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter)) + + def set_process(self, *args, **kwargs): + return self._reporter.set_process(*args, **kwargs) + + def report_span(self, span): + try: + return self._reporter.report_span(span) + except Exception: + logger.exception("Failed to report span") + + RustReporter = _WrappedRustReporter # type: Optional[Type[_WrappedRustReporter]] +except ImportError: + RustReporter = None + + logger = logging.getLogger(__name__) @@ -320,11 +345,19 @@ def init_tracer(hs: "HomeServer"): set_homeserver_whitelist(hs.config.opentracer_whitelist) - JaegerConfig( + config = JaegerConfig( config=hs.config.jaeger_config, service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), scope_manager=LogContextScopeManager(hs.config), - ).initialize_tracer() + ) + + # If we have the rust jaeger reporter available let's use that. + if RustReporter: + logger.info("Using rust_python_jaeger_reporter library") + tracer = config.create_tracer(RustReporter(), config.sampler) + opentracing.set_global_tracer(tracer) + else: + config.initialize_tracer() # Whitelisting @@ -466,7 +499,9 @@ def start_active_span_from_edu( if opentracing is None: return _noop_context_manager() - carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {}) + carrier = json_decoder.decode(edu_content.get("context", "{}")).get( + "opentracing", {} + ) context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) _references = [ opentracing.child_of(span_context_from_string(x)) @@ -657,7 +692,7 @@ def active_span_context_as_string(): opentracing.tracer.inject( opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier ) - return json.dumps(carrier) + return json_encoder.encode(carrier) @only_if_tracing @@ -666,7 +701,7 @@ def span_context_from_string(carrier): Returns: The active span context decoded from a string. """ - carrier = json.loads(carrier) + carrier = json_decoder.decode(carrier) return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) @@ -700,37 +735,43 @@ def trace(func=None, opname=None): _opname = opname if opname else func.__name__ - @wraps(func) - def _trace_inner(*args, **kwargs): - if opentracing is None: - return func(*args, **kwargs) + if inspect.iscoroutinefunction(func): - scope = start_active_span(_opname) - scope.__enter__() + @wraps(func) + async def _trace_inner(*args, **kwargs): + with start_active_span(_opname): + return await func(*args, **kwargs) - try: - result = func(*args, **kwargs) - if isinstance(result, defer.Deferred): + else: + # The other case here handles both sync functions and those + # decorated with inlineDeferred. + @wraps(func) + def _trace_inner(*args, **kwargs): + scope = start_active_span(_opname) + scope.__enter__() - def call_back(result): - scope.__exit__(None, None, None) - return result + try: + result = func(*args, **kwargs) + if isinstance(result, defer.Deferred): - def err_back(result): - scope.span.set_tag(tags.ERROR, True) - scope.__exit__(None, None, None) - return result + def call_back(result): + scope.__exit__(None, None, None) + return result - result.addCallbacks(call_back, err_back) + def err_back(result): + scope.__exit__(None, None, None) + return result - else: - scope.__exit__(None, None, None) + result.addCallbacks(call_back, err_back) - return result + else: + scope.__exit__(None, None, None) - except Exception as e: - scope.__exit__(type(e), None, e.__traceback__) - raise + return result + + except Exception as e: + scope.__exit__(type(e), None, e.__traceback__) + raise return _trace_inner @@ -760,48 +801,42 @@ def tag_args(func): return _tag_args_inner -def trace_servlet(servlet_name, extract_context=False): - """Decorator which traces a serlet. It starts a span with some servlet specific - tags such as the servlet_name and request information +@contextlib.contextmanager +def trace_servlet(request: "SynapseRequest", extract_context: bool = False): + """Returns a context manager which traces a request. It starts a span + with some servlet specific tags such as the request metrics name and + request information. Args: - servlet_name (str): The name to be used for the span's operation_name - extract_context (bool): Whether to attempt to extract the opentracing + request + extract_context: Whether to attempt to extract the opentracing context from the request the servlet is handling. - """ - def _trace_servlet_inner_1(func): - if not opentracing: - return func - - @wraps(func) - async def _trace_servlet_inner(request, *args, **kwargs): - request_tags = { - "request_id": request.get_request_id(), - tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, - tags.HTTP_METHOD: request.get_method(), - tags.HTTP_URL: request.get_redacted_uri(), - tags.PEER_HOST_IPV6: request.getClientIP(), - } - - if extract_context: - scope = start_active_span_from_request( - request, servlet_name, tags=request_tags - ) - else: - scope = start_active_span(servlet_name, tags=request_tags) - - with scope: - result = func(request, *args, **kwargs) + if opentracing is None: + yield + return - if not isinstance(result, (types.CoroutineType, defer.Deferred)): - # Some servlets aren't async and just return results - # directly, so we handle that here. - return result + request_tags = { + "request_id": request.get_request_id(), + tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, + tags.HTTP_METHOD: request.get_method(), + tags.HTTP_URL: request.get_redacted_uri(), + tags.PEER_HOST_IPV6: request.getClientIP(), + } - return await result + request_name = request.request_metrics.name + if extract_context: + scope = start_active_span_from_request(request, request_name, tags=request_tags) + else: + scope = start_active_span(request_name, tags=request_tags) - return _trace_servlet_inner + with scope: + try: + yield + finally: + # We set the operation name again in case its changed (which happens + # with JsonResource). + scope.span.set_operation_name(request.request_metrics.name) - return _trace_servlet_inner_1 + scope.span.set_tag("request_tag", request.request_metrics.start_context.tag) diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index dc3ab00cbb..026854b4c7 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -116,6 +116,8 @@ class _LogContextScope(Scope): if self._enter_logcontext: self.logcontext.__enter__() + return self + def __exit__(self, type, value, traceback): if type == twisted.internet.defer._DefGen_Return: super(_LogContextScope, self).__exit__(None, None, None) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 99049bb5d8..fea774e2e5 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -14,9 +14,7 @@ # limitations under the License. -import inspect import logging -import time from functools import wraps from inspect import getcallargs @@ -74,127 +72,3 @@ def log_function(f): wrapped.__name__ = func_name return wrapped - - -def time_function(f): - func_name = f.__name__ - - @wraps(f) - def wrapped(*args, **kwargs): - global _TIME_FUNC_ID - id = _TIME_FUNC_ID - _TIME_FUNC_ID += 1 - - start = time.clock() - - try: - _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id)) - - r = f(*args, **kwargs) - finally: - end = time.clock() - _log_debug_as_f( - f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start) - ) - - return r - - return wrapped - - -def trace_function(f): - func_name = f.__name__ - linenum = f.func_code.co_firstlineno - pathname = f.func_code.co_filename - - @wraps(f) - def wrapped(*args, **kwargs): - name = f.__module__ - logger = logging.getLogger(name) - level = logging.DEBUG - - frame = inspect.currentframe() - if frame is None: - raise Exception("Can't get current frame!") - - s = frame.f_back - - to_print = [ - "\t%s:%s %s. Args: args=%s, kwargs=%s" - % (pathname, linenum, func_name, args, kwargs) - ] - while s: - if True or s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_print.append( - "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string) - ) - - s = s.f_back - - msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print) - - record = logging.LogRecord( - name=name, - level=level, - pathname=pathname, - lineno=lineno, - msg=msg, - args=(), - exc_info=None, - ) - - logger.handle(record) - - return f(*args, **kwargs) - - wrapped.__name__ = func_name - return wrapped - - -def get_previous_frames(): - - frame = inspect.currentframe() - if frame is None: - raise Exception("Can't get current frame!") - - s = frame.f_back.f_back - to_return = [] - while s: - if s.f_globals["__name__"].startswith("synapse"): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - to_return.append( - "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string) - ) - - s = s.f_back - - return ", ".join(to_return) - - -def get_previous_frame(ignore=[]): - frame = inspect.currentframe() - if frame is None: - raise Exception("Can't get current frame!") - s = frame.f_back.f_back - - while s: - if s.f_globals["__name__"].startswith("synapse"): - if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore): - filename, lineno, function, _, _ = inspect.getframeinfo(s) - args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - - return "{{ %s:%d %s - Args: %s }}" % ( - filename, - lineno, - function, - args_string, - ) - - s = s.f_back - - return None diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 9cf31f96b3..2643380d9e 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -22,8 +22,6 @@ import threading import time from typing import Callable, Dict, Iterable, Optional, Tuple, Union -import six - import attr from prometheus_client import Counter, Gauge, Histogram from prometheus_client.core import ( @@ -53,7 +51,7 @@ all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollec HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") -class RegistryProxy(object): +class RegistryProxy: @staticmethod def collect(): for metric in REGISTRY.collect(): @@ -62,7 +60,7 @@ class RegistryProxy(object): @attr.s(hash=True) -class LaterGauge(object): +class LaterGauge: name = attr.ib(type=str) desc = attr.ib(type=str) @@ -83,7 +81,7 @@ class LaterGauge(object): return if isinstance(calls, dict): - for k, v in six.iteritems(calls): + for k, v in calls.items(): g.add_metric(k, v) else: g.add_metric([], calls) @@ -102,7 +100,7 @@ class LaterGauge(object): all_gauges[self.name] = self -class InFlightGauge(object): +class InFlightGauge: """Tracks number of things (e.g. requests, Measure blocks, etc) in flight at any given time. @@ -194,7 +192,7 @@ class InFlightGauge(object): gauge = GaugeMetricFamily( "_".join([self.name, name]), "", labels=self.labels ) - for key, metrics in six.iteritems(metrics_by_key): + for key, metrics in metrics_by_key.items(): gauge.add_metric(key, getattr(metrics, name)) yield gauge @@ -208,7 +206,7 @@ class InFlightGauge(object): @attr.s(hash=True) -class BucketCollector(object): +class BucketCollector: """ Like a Histogram, but allows buckets to be point-in-time instead of incrementally added to. @@ -271,7 +269,7 @@ class BucketCollector(object): # -class CPUMetrics(object): +class CPUMetrics: def __init__(self): ticks_per_sec = 100 try: @@ -331,7 +329,7 @@ gc_time = Histogram( ) -class GCCounts(object): +class GCCounts: def collect(self): cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) for n, m in enumerate(gc.get_count()): @@ -349,7 +347,7 @@ if not running_on_pypy: # -class PyPyGCStats(object): +class PyPyGCStats: def collect(self): # @stats is a pretty-printer object with __str__() returning a nice table, @@ -465,6 +463,12 @@ event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name" # finished being processed. event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"]) +event_processing_lag_by_event = Histogram( + "synapse_event_processing_lag_by_event", + "Time between an event being persisted and it being queued up to be sent to the relevant remote servers", + ["name"], +) + # Build info of the running server. build_info = Gauge( "synapse_build_info", "Build information", ["pythonversion", "version", "osversion"] @@ -478,7 +482,7 @@ build_info.labels( last_ticked = time.time() -class ReactorLastSeenMetric(object): +class ReactorLastSeenMetric: def collect(self): cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index ab7f948ed4..4304c60d56 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -208,6 +208,7 @@ class MetricsHandler(BaseHTTPRequestHandler): raise self.send_response(200) self.send_header("Content-Type", CONTENT_TYPE_LATEST) + self.send_header("Content-Length", str(len(output))) self.end_headers() self.wfile.write(output) @@ -261,4 +262,6 @@ class MetricsResource(Resource): def render_GET(self, request): request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) - return generate_latest(self.registry) + response = generate_latest(self.registry) + request.setHeader(b"Content-Length", str(len(response))) + return response diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 13785038ad..5b73463504 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging import threading -from asyncio import iscoroutine from functools import wraps from typing import TYPE_CHECKING, Dict, Optional, Set @@ -105,7 +105,7 @@ _background_processes_active_since_last_scrape = set() # type: Set[_BackgroundP _bg_metrics_lock = threading.Lock() -class _Collector(object): +class _Collector: """A custom metrics collector for the background process metrics. Ensures that all of the metrics are up-to-date with any in-flight processes @@ -140,7 +140,7 @@ class _Collector(object): REGISTRY.register(_Collector()) -class _BackgroundProcess(object): +class _BackgroundProcess: def __init__(self, desc, ctx): self.desc = desc self._context = ctx @@ -166,7 +166,7 @@ class _BackgroundProcess(object): ) -def run_as_background_process(desc, func, *args, **kwargs): +def run_as_background_process(desc: str, func, *args, **kwargs): """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -175,10 +175,10 @@ def run_as_background_process(desc, func, *args, **kwargs): It returns a Deferred which completes when the function completes, but it doesn't follow the synapse logcontext rules, which makes it appropriate for passing to clock.looping_call and friends (or for firing-and-forgetting in the middle of a - normal synapse inlineCallbacks function). + normal synapse async function). Args: - desc (str): a description for this background process type + desc: a description for this background process type func: a function, which may return a Deferred or a coroutine args: positional args for func kwargs: keyword args for func @@ -187,8 +187,7 @@ def run_as_background_process(desc, func, *args, **kwargs): follow the synapse logcontext rules. """ - @defer.inlineCallbacks - def run(): + async def run(): with _bg_metrics_lock: count = _background_process_counts.get(desc, 0) _background_process_counts[desc] = count + 1 @@ -202,22 +201,21 @@ def run_as_background_process(desc, func, *args, **kwargs): try: result = func(*args, **kwargs) - # We probably don't have an ensureDeferred in our call stack to handle - # coroutine results, so we need to ensureDeferred here. - # - # But we need this check because ensureDeferred doesn't like being - # called on immediate values (as opposed to Deferreds or coroutines). - if iscoroutine(result): - result = defer.ensureDeferred(result) + if inspect.isawaitable(result): + result = await result - return (yield result) + return result except Exception: - logger.exception("Background process '%s' threw an exception", desc) + logger.exception( + "Background process '%s' threw an exception", desc, + ) finally: _background_process_in_flight_count.labels(desc).dec() with PreserveLoggingContext(): - return run() + # Note that we return a Deferred here so that it can be used in a + # looping_call and other places that expect a Deferred. + return defer.ensureDeferred(run()) def wrap_as_background_process(desc): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index a7849cefa5..fcbd5378c4 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -31,7 +31,7 @@ __all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi" logger = logging.getLogger(__name__) -class ModuleApi(object): +class ModuleApi: """A proxy object that gets passed to various plugin modules so they can register new users etc if necessary. """ @@ -167,8 +167,10 @@ class ModuleApi(object): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self._store.record_user_external_id( - auth_provider_id, remote_user_id, registered_user_id + return defer.ensureDeferred( + self._store.record_user_external_id( + auth_provider_id, remote_user_id, registered_user_id + ) ) def generate_short_term_login_token( @@ -194,12 +196,16 @@ class ModuleApi(object): synapse.api.errors.AuthError: the access token is invalid """ # see if the access token corresponds to a device - user_info = yield self._auth.get_user_by_access_token(access_token) + user_info = yield defer.ensureDeferred( + self._auth.get_user_by_access_token(access_token) + ) device_id = user_info.get("device_id") user_id = user_info["user"].to_string() if device_id: # delete the device, which will also delete its access tokens - yield self._hs.get_device_handler().delete_device(user_id, device_id) + yield defer.ensureDeferred( + self._hs.get_device_handler().delete_device(user_id, device_id) + ) else: # no associated device. Just delete the access token. yield defer.ensureDeferred( @@ -219,7 +225,9 @@ class ModuleApi(object): Returns: Deferred[object]: result of func """ - return self._store.db.runInteraction(desc, func, *args, **kwargs) + return defer.ensureDeferred( + self._store.db_pool.runInteraction(desc, func, *args, **kwargs) + ) def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str diff --git a/synapse/notifier.py b/synapse/notifier.py index 87c120a59c..b7f4041306 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,7 +15,18 @@ import logging from collections import namedtuple -from typing import Callable, Iterable, List, TypeVar +from typing import ( + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from prometheus_client import Counter @@ -24,12 +35,14 @@ from twisted.internet import defer import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError +from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import PreserveLoggingContext from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import StreamToken +from synapse.streams.config import PaginationConfig +from synapse.types import Collection, StreamToken, UserID from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -55,7 +68,7 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int: return n -class _NotificationListener(object): +class _NotificationListener: """ This represents a single client connection to the events stream. The events stream handler will have yielded to the deferred, so to notify the handler it is sufficient to resolve the deferred. @@ -67,7 +80,7 @@ class _NotificationListener(object): self.deferred = deferred -class _NotifierUserStream(object): +class _NotifierUserStream: """This represents a user connected to the event stream. It tracks the most recent stream token for that user. At a given point a user may have a number of streams listening for @@ -77,13 +90,19 @@ class _NotifierUserStream(object): so that it can remove itself from the indexes in the Notifier class. """ - def __init__(self, user_id, rooms, current_token, time_now_ms): + def __init__( + self, + user_id: str, + rooms: Collection[str], + current_token: StreamToken, + time_now_ms: int, + ): self.user_id = user_id self.rooms = set(rooms) self.current_token = current_token # The last token for which we should wake up any streams that have a - # token that comes before it. This gets updated everytime we get poked. + # token that comes before it. This gets updated every time we get poked. # We start it at the current token since if we get any streams # that have a token from before we have no idea whether they should be # woken up or not, so lets just wake them up. @@ -93,13 +112,13 @@ class _NotifierUserStream(object): with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - def notify(self, stream_key, stream_id, time_now_ms): + def notify(self, stream_key: str, stream_id: int, time_now_ms: int): """Notify any listeners for this user of a new event from an event source. Args: - stream_key(str): The stream the event came from. - stream_id(str): The new id for the stream the event came from. - time_now_ms(int): The current time in milliseconds. + stream_key: The stream the event came from. + stream_id: The new id for the stream the event came from. + time_now_ms: The current time in milliseconds. """ self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token @@ -112,7 +131,7 @@ class _NotifierUserStream(object): self.notify_deferred = ObservableDeferred(defer.Deferred()) noify_deferred.callback(self.current_token) - def remove(self, notifier): + def remove(self, notifier: "Notifier"): """ Remove this listener from all the indexes in the Notifier it knows about. """ @@ -123,10 +142,10 @@ class _NotifierUserStream(object): notifier.user_to_user_stream.pop(self.user_id) - def count_listeners(self): + def count_listeners(self) -> int: return len(self.notify_deferred.observers()) - def new_listener(self, token): + def new_listener(self, token: StreamToken) -> _NotificationListener: """Returns a deferred that is resolved when there is a new token greater than the given token. @@ -149,7 +168,7 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): __bool__ = __nonzero__ # python3 -class Notifier(object): +class Notifier: """ This class is responsible for notifying any listeners when there are new events available for it. @@ -159,14 +178,16 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 def __init__(self, hs: "synapse.server.HomeServer"): - self.user_to_user_stream = {} - self.room_to_user_streams = {} + self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream] + self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]] self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = [] + self.pending_new_room_events = ( + [] + ) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -178,10 +199,9 @@ class Notifier(object): self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() + self.federation_sender = None if hs.should_send_federation(): self.federation_sender = hs.get_federation_sender() - else: - self.federation_sender = None self.state_handler = hs.get_state_handler() @@ -193,12 +213,12 @@ class Notifier(object): # when rendering the metrics page, which is likely once per minute at # most when scraping it. def count_listeners(): - all_user_streams = set() + all_user_streams = set() # type: Set[_NotifierUserStream] - for x in list(self.room_to_user_streams.values()): - all_user_streams |= x - for x in list(self.user_to_user_stream.values()): - all_user_streams.add(x) + for streams in list(self.room_to_user_streams.values()): + all_user_streams |= streams + for stream in list(self.user_to_user_stream.values()): + all_user_streams.add(stream) return sum(stream.count_listeners() for stream in all_user_streams) @@ -223,7 +243,11 @@ class Notifier(object): self.replication_callbacks.append(cb) def on_new_room_event( - self, event, room_stream_id, max_room_stream_id, extra_users=[] + self, + event: EventBase, + room_stream_id: int, + max_room_stream_id: int, + extra_users: Collection[Union[str, UserID]] = [], ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -241,11 +265,11 @@ class Notifier(object): self.notify_replication() - def _notify_pending_new_room_events(self, max_room_stream_id): + def _notify_pending_new_room_events(self, max_room_stream_id: int): """Notify for the room events that were queued waiting for a previous event to be persisted. Args: - max_room_stream_id(int): The highest stream_id below which all + max_room_stream_id: The highest stream_id below which all events have been persisted. """ pending = self.pending_new_room_events @@ -258,7 +282,12 @@ class Notifier(object): else: self._on_new_room_event(event, room_stream_id, extra_users) - def _on_new_room_event(self, event, room_stream_id, extra_users=[]): + def _on_new_room_event( + self, + event: EventBase, + room_stream_id: int, + extra_users: Collection[Union[str, UserID]] = [], + ): """Notify any user streams that are interested in this room event""" # poke any interested application service. run_as_background_process( @@ -275,13 +304,19 @@ class Notifier(object): "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) - async def _notify_app_services(self, room_stream_id): + async def _notify_app_services(self, room_stream_id: int): try: await self.appservice_handler.notify_interested_services(room_stream_id) except Exception: logger.exception("Error notifying application services of event") - def on_new_event(self, stream_key, new_token, users=[], rooms=[]): + def on_new_event( + self, + stream_key: str, + new_token: int, + users: Collection[Union[str, UserID]] = [], + rooms: Collection[str] = [], + ): """ Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. @@ -307,20 +342,25 @@ class Notifier(object): self.notify_replication() - def on_new_replication_data(self): + def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" self.notify_replication() async def wait_for_events( - self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START - ): + self, + user_id: str, + timeout: int, + callback: Callable[[StreamToken, StreamToken], Awaitable[T]], + room_ids=None, + from_token=StreamToken.START, + ) -> T: """Wait until the callback returns a non empty response or the timeout fires. """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - current_token = await self.event_sources.get_current_token() + current_token = self.event_sources.get_current_token() if room_ids is None: room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( @@ -377,19 +417,16 @@ class Notifier(object): async def get_events_for( self, - user, - pagination_config, - timeout, - only_keys=None, - is_guest=False, - explicit_room_id=None, - ): + user: UserID, + pagination_config: PaginationConfig, + timeout: int, + is_guest: bool = False, + explicit_room_id: str = None, + ) -> EventStreamResult: """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. - If `only_keys` is not None, events from keys will be sent down. - If explicit_room_id is not set, the user's joined rooms will be polled for events. If explicit_room_id is set, that room will be polled for events only if @@ -397,18 +434,20 @@ class Notifier(object): """ from_token = pagination_config.from_token if not from_token: - from_token = await self.event_sources.get_current_token() + from_token = self.event_sources.get_current_token() limit = pagination_config.limit room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - async def check_for_updates(before_token, after_token): + async def check_for_updates( + before_token: StreamToken, after_token: StreamToken + ) -> EventStreamResult: if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) - events = [] + events = [] # type: List[EventBase] end_token = from_token for name, source in self.event_sources.sources.items(): @@ -417,8 +456,6 @@ class Notifier(object): after_id = getattr(after_token, keyname) if before_id == after_id: continue - if only_keys and name not in only_keys: - continue new_events, new_key = await source.get_new_events( user=user, @@ -476,7 +513,9 @@ class Notifier(object): return result - async def _get_room_ids(self, user, explicit_room_id): + async def _get_room_ids( + self, user: UserID, explicit_room_id: Optional[str] + ) -> Tuple[Collection[str], bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: @@ -486,7 +525,7 @@ class Notifier(object): raise AuthError(403, "Non-joined access not allowed") return joined_room_ids, True - async def _is_world_readable(self, room_id): + async def _is_world_readable(self, room_id: str) -> bool: state = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) @@ -496,7 +535,7 @@ class Notifier(object): return False @log_function - def remove_expired_streams(self): + def remove_expired_streams(self) -> None: time_now_ms = self.clock.time_msec() expired_streams = [] expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS @@ -510,21 +549,21 @@ class Notifier(object): expired_stream.remove(self) @log_function - def _register_with_keys(self, user_stream): + def _register_with_keys(self, user_stream: _NotifierUserStream): self.user_to_user_stream[user_stream.user_id] = user_stream for room in user_stream.rooms: s = self.room_to_user_streams.setdefault(room, set()) s.add(user_stream) - def _user_joined_room(self, user_id, room_id): + def _user_joined_room(self, user_id: str, room_id: str): new_user_stream = self.user_to_user_stream.get(user_id) if new_user_stream is not None: room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams.add(new_user_stream) new_user_stream.rooms.add(room_id) - def notify_replication(self): + def notify_replication(self) -> None: """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 1ffd5e2df3..fabc9ba126 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.metrics import Measure from .bulk_push_rule_evaluator import BulkPushRuleEvaluator @@ -24,7 +22,7 @@ from .bulk_push_rule_evaluator import BulkPushRuleEvaluator logger = logging.getLogger(__name__) -class ActionGenerator(object): +class ActionGenerator: def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() @@ -37,7 +35,6 @@ class ActionGenerator(object): # event stream, so we just run the rules for a client with no profile # tag (ie. we just need all the users). - @defer.inlineCallbacks - def handle_push_actions_for_event(self, event, context): + async def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - yield self.bulk_evaluator.action_for_event_by_user(event, context) + await self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 286374d0b5..8047873ff1 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -19,11 +19,13 @@ import copy from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP -def list_with_base_rules(rawrules): +def list_with_base_rules(rawrules, use_new_defaults=False): """Combine the list of rules set by the user with the default push rules Args: rawrules(list): The rules the user has modified or set. + use_new_defaults(bool): Whether to use the new experimental default rules when + appending or prepending default rules. Returns: A new list with the rules set by the user combined with the defaults. @@ -43,7 +45,9 @@ def list_with_base_rules(rawrules): ruleslist.extend( make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + use_new_defaults, ) ) @@ -54,6 +58,7 @@ def list_with_base_rules(rawrules): make_base_append_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, + use_new_defaults, ) ) current_prio_class -= 1 @@ -62,6 +67,7 @@ def list_with_base_rules(rawrules): make_base_prepend_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, + use_new_defaults, ) ) @@ -70,27 +76,39 @@ def list_with_base_rules(rawrules): while current_prio_class > 0: ruleslist.extend( make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + use_new_defaults, ) ) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend( make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + use_new_defaults, ) ) return ruleslist -def make_base_append_rules(kind, modified_base_rules): +def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False): rules = [] if kind == "override": - rules = BASE_APPEND_OVERRIDE_RULES + rules = ( + NEW_APPEND_OVERRIDE_RULES + if use_new_defaults + else BASE_APPEND_OVERRIDE_RULES + ) elif kind == "underride": - rules = BASE_APPEND_UNDERRIDE_RULES + rules = ( + NEW_APPEND_UNDERRIDE_RULES + if use_new_defaults + else BASE_APPEND_UNDERRIDE_RULES + ) elif kind == "content": rules = BASE_APPEND_CONTENT_RULES @@ -105,7 +123,7 @@ def make_base_append_rules(kind, modified_base_rules): return rules -def make_base_prepend_rules(kind, modified_base_rules): +def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False): rules = [] if kind == "override": @@ -270,6 +288,135 @@ BASE_APPEND_OVERRIDE_RULES = [ ] +NEW_APPEND_OVERRIDE_RULES = [ + { + "rule_id": "global/override/.m.rule.encrypted", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.encrypted", + "_id": "_encrypted", + } + ], + "actions": ["notify"], + }, + { + "rule_id": "global/override/.m.rule.suppress_notices", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.message", + "_id": "_suppress_notices_type", + }, + { + "kind": "event_match", + "key": "content.msgtype", + "pattern": "m.notice", + "_id": "_suppress_notices", + }, + ], + "actions": [], + }, + { + "rule_id": "global/underride/.m.rule.suppress_edits", + "conditions": [ + { + "kind": "event_match", + "key": "m.relates_to.m.rel_type", + "pattern": "m.replace", + "_id": "_suppress_edits", + } + ], + "actions": [], + }, + { + "rule_id": "global/override/.m.rule.invite_for_me", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.member", + "_id": "_member", + }, + { + "kind": "event_match", + "key": "content.membership", + "pattern": "invite", + "_id": "_invite_member", + }, + {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, + ], + "actions": ["notify", {"set_tweak": "sound", "value": "default"}], + }, + { + "rule_id": "global/override/.m.rule.contains_display_name", + "conditions": [{"kind": "contains_display_name"}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + ], + }, + { + "rule_id": "global/override/.m.rule.tombstone", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.tombstone", + "_id": "_tombstone", + }, + { + "kind": "event_match", + "key": "state_key", + "pattern": "", + "_id": "_tombstone_statekey", + }, + ], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + ], + }, + { + "rule_id": "global/override/.m.rule.roomnotif", + "conditions": [ + { + "kind": "event_match", + "key": "content.body", + "pattern": "@room", + "_id": "_roomnotif_content", + }, + { + "kind": "sender_notification_permission", + "key": "room", + "_id": "_roomnotif_pl", + }, + ], + "actions": [ + "notify", + {"set_tweak": "highlight"}, + {"set_tweak": "sound", "value": "default"}, + ], + }, + { + "rule_id": "global/override/.m.rule.call", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.call.invite", + "_id": "_call", + } + ], + "actions": ["notify", {"set_tweak": "sound", "value": "ring"}], + }, +] + + BASE_APPEND_UNDERRIDE_RULES = [ { "rule_id": "global/underride/.m.rule.call", @@ -354,6 +501,36 @@ BASE_APPEND_UNDERRIDE_RULES = [ ] +NEW_APPEND_UNDERRIDE_RULES = [ + { + "rule_id": "global/underride/.m.rule.room_one_to_one", + "conditions": [ + {"kind": "room_member_count", "is": "2", "_id": "member_count"}, + { + "kind": "event_match", + "key": "content.body", + "pattern": "*", + "_id": "body", + }, + ], + "actions": ["notify", {"set_tweak": "sound", "value": "default"}], + }, + { + "rule_id": "global/underride/.m.rule.message", + "conditions": [ + { + "kind": "event_match", + "key": "content.body", + "pattern": "*", + "_id": "body", + }, + ], + "actions": ["notify"], + "enabled": False, + }, +] + + BASE_RULE_IDS = set() for r in BASE_APPEND_CONTENT_RULES: @@ -375,3 +552,26 @@ for r in BASE_APPEND_UNDERRIDE_RULES: r["priority_class"] = PRIORITY_CLASS_MAP["underride"] r["default"] = True BASE_RULE_IDS.add(r["rule_id"]) + + +NEW_RULE_IDS = set() + +for r in BASE_APPEND_CONTENT_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["content"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) + +for r in BASE_PREPEND_OVERRIDE_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) + +for r in NEW_APPEND_OVERRIDE_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) + +for r in NEW_APPEND_UNDERRIDE_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["underride"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index e75d964ac8..c440f2545c 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -17,14 +17,12 @@ import logging from collections import namedtuple -from six import iteritems, itervalues - from prometheus_client import Counter -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.event_auth import get_user_power_level +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.util.async_helpers import Linearizer from synapse.util.caches import register_cache @@ -55,7 +53,49 @@ push_rules_delta_state_cache_metric = register_cache( ) -class BulkPushRuleEvaluator(object): +STATE_EVENT_TYPES_TO_MARK_UNREAD = { + EventTypes.Topic, + EventTypes.Name, + EventTypes.RoomAvatar, + EventTypes.Tombstone, +} + + +def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: + # Exclude rejected and soft-failed events. + if context.rejected or event.internal_metadata.is_soft_failed(): + return False + + # Exclude notices. + if ( + not event.is_state() + and event.type == EventTypes.Message + and event.content.get("msgtype") == "m.notice" + ): + return False + + # Exclude edits. + relates_to = event.content.get("m.relates_to", {}) + if relates_to.get("rel_type") == RelationTypes.REPLACE: + return False + + # Mark events that have a non-empty string body as unread. + body = event.content.get("body") + if isinstance(body, str) and body: + return True + + # Mark some state events as unread. + if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: + return True + + # Mark encrypted events as unread. + if not event.is_state() and event.type == EventTypes.Encrypted: + return True + + return False + + +class BulkPushRuleEvaluator: """Calculates the outcome of push rules for an event for all users in the room at once. """ @@ -72,8 +112,7 @@ class BulkPushRuleEvaluator(object): resizable=False, ) - @defer.inlineCallbacks - def _get_rules_for_event(self, event, context): + async def _get_rules_for_event(self, event, context): """This gets the rules for all users in the room at the time of the event, as well as the push rules for the invitee if the event is an invite. @@ -81,19 +120,19 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = yield self._get_rules_for_room(room_id) + rules_for_room = await self._get_rules_for_room(room_id) - rules_by_user = yield rules_for_room.get_rules(event, context) + rules_by_user = await rules_for_room.get_rules(event, context) # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): - has_pusher = yield self.store.user_has_pusher(invited) + has_pusher = await self.store.user_has_pusher(invited) if has_pusher: rules_by_user = dict(rules_by_user) - rules_by_user[invited] = yield self.store.get_push_rules_for_user( + rules_by_user[invited] = await self.store.get_push_rules_for_user( invited ) @@ -116,21 +155,20 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics, ) - @defer.inlineCallbacks - def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids() + async def _get_power_levels_and_sender_level(self, event, context): + prev_state_ids = await context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and # not having a power level event is an extreme edge case - pl_event = yield self.store.get_event(pl_event_id) + pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = yield self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) - auth_events = yield self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} + auth_events = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -138,23 +176,22 @@ class BulkPushRuleEvaluator(object): return pl_event.content if pl_event else {}, sender_level - @defer.inlineCallbacks - def action_for_event_by_user(self, event, context): - """Given an event and context, evaluate the push rules and insert the - results into the event_push_actions_staging table. - - Returns: - Deferred + async def action_for_event_by_user(self, event, context) -> None: + """Given an event and context, evaluate the push rules, check if the message + should increment the unread count, and insert the results into the + event_push_actions_staging table. """ - rules_by_user = yield self._get_rules_for_event(event, context) + count_as_unread = _should_count_as_unread(event, context) + + rules_by_user = await self._get_rules_for_event(event, context) actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context(event, context) + room_members = await self.store.get_joined_users_from_context(event, context) ( power_levels, sender_power_level, - ) = yield self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels @@ -162,12 +199,12 @@ class BulkPushRuleEvaluator(object): condition_cache = {} - for uid, rules in iteritems(rules_by_user): + for uid, rules in rules_by_user.items(): if event.sender == uid: continue if not event.is_state(): - is_ignored = yield self.store.is_ignored_by(event.sender, uid) + is_ignored = await self.store.is_ignored_by(event.sender, uid) if is_ignored: continue @@ -182,6 +219,13 @@ class BulkPushRuleEvaluator(object): if event.type == EventTypes.Member and event.state_key == uid: display_name = event.content.get("displayname", None) + if count_as_unread: + # Add an element for the current user if the event needs to be marked as + # unread, so that add_push_actions_to_staging iterates over it. + # If the event shouldn't be marked as unread but should notify the + # current user, it'll be added to the dict later. + actions_by_user[uid] = [] + for rule in rules: if "enabled" in rule and not rule["enabled"]: continue @@ -199,7 +243,9 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) + await self.store.add_push_actions_to_staging( + event.event_id, actions_by_user, count_as_unread, + ) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -222,7 +268,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): return True -class RulesForRoom(object): +class RulesForRoom: """Caches push rules for users in a room. This efficiently handles users joining/leaving the room by not invalidating @@ -276,8 +322,7 @@ class RulesForRoom(object): # to self around in the callback. self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) - @defer.inlineCallbacks - def get_rules(self, event, context): + async def get_rules(self, event, context): """Given an event context return the rules for all users who are currently in the room. """ @@ -288,7 +333,7 @@ class RulesForRoom(object): self.room_push_rule_cache_metrics.inc_hits() return self.rules_by_user - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_group and self.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() @@ -306,7 +351,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids() + current_state_ids = await context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) @@ -353,7 +398,7 @@ class RulesForRoom(object): # If we have some memebr events we haven't seen, look them up # and fetch push rules for them if appropriate. logger.debug("Found new member events %r", missing_member_event_ids) - yield self._update_rules_with_member_event_ids( + await self._update_rules_with_member_event_ids( ret_rules_by_user, missing_member_event_ids, state_group, event ) else: @@ -371,8 +416,7 @@ class RulesForRoom(object): ) return ret_rules_by_user - @defer.inlineCallbacks - def _update_rules_with_member_event_ids( + async def _update_rules_with_member_event_ids( self, ret_rules_by_user, member_event_ids, state_group, event ): """Update the partially filled rules_by_user dict by fetching rules for @@ -381,62 +425,47 @@ class RulesForRoom(object): Args: ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets updated with any new rules. - member_event_ids (list): List of event ids for membership events that - have happened since the last time we filled rules_by_user + member_event_ids (dict): Dict of user id to event id for membership events + that have happened since the last time we filled rules_by_user state_group: The state group we are currently computing push rules for. Used when updating the cache. """ sequence = self.sequence - rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) + rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} # If the event is a join event then it will be in current state evnts # map but not in the DB, so we have to explicitly insert it. if event.type == EventTypes.Member: - for event_id in itervalues(member_event_ids): + for event_id in member_event_ids.values(): if event_id == event.event_id: members[event_id] = (event.state_key, event.membership) if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) - interested_in_user_ids = { + user_ids = { user_id - for user_id, membership in itervalues(members) + for user_id, membership in members.values() if membership == Membership.JOIN } - logger.debug("Joined: %r", interested_in_user_ids) - - if_users_with_pushers = yield self.store.get_if_users_have_pushers( - interested_in_user_ids, on_invalidate=self.invalidate_all_cb - ) - - user_ids = { - uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher - } - - logger.debug("With pushers: %r", user_ids) - - users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( - self.room_id, on_invalidate=self.invalidate_all_cb - ) - - logger.debug("With receipts: %r", users_with_receipts) + logger.debug("Joined: %r", user_ids) - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in interested_in_user_ids: - user_ids.add(uid) + # Previously we only considered users with pushers or read receipts in that + # room. We can't do this anymore because we use push actions to calculate unread + # counts, which don't rely on the user having pushers or sent a read receipt into + # the room. Therefore we just need to filter for local users here. + user_ids = list(filter(self.is_mine_id, user_ids)) - rules_by_user = yield self.store.bulk_get_push_rules( + rules_by_user = await self.store.bulk_get_push_rules( user_ids, on_invalidate=self.invalidate_all_cb ) ret_rules_by_user.update( - item for item in iteritems(rules_by_user) if item[0] is not None + item for item in rules_by_user.items() if item[0] is not None ) self.update_cache(sequence, members, ret_rules_by_user, state_group) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 568c13eaea..b7ea4438e0 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -45,7 +45,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000 INCLUDE_ALL_UNREAD_NOTIFS = False -class EmailPusher(object): +class EmailPusher: """ A pusher that sends email notifications about events (approximately) when they happen. diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index eaaa7afc91..f21fa9b659 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -17,9 +17,9 @@ import logging from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled +from synapse.api.constants import EventTypes from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException @@ -49,7 +49,7 @@ http_badges_failed_counter = Counter( ) -class HttpPusher(object): +class HttpPusher: INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes MAX_BACKOFF_SEC = 60 * 60 @@ -127,10 +127,11 @@ class HttpPusher(object): # but currently that's the only type of receipt anyway... run_as_background_process("http_pusher.on_new_receipts", self._update_badge) - @defer.inlineCallbacks - def _update_badge(self): - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - yield self._send_badge(badge) + async def _update_badge(self): + # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems + # to be largely redundant. perhaps we can remove it. + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + await self._send_badge(badge) def on_timer(self): self._start_processing() @@ -149,8 +150,7 @@ class HttpPusher(object): run_as_background_process("httppush.process", self._process) - @defer.inlineCallbacks - def _process(self): + async def _process(self): # we should never get here if we are already processing assert not self._is_processing @@ -161,7 +161,7 @@ class HttpPusher(object): while True: starting_max_ordering = self.max_stream_ordering try: - yield self._unsafe_process() + await self._unsafe_process() except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: @@ -169,8 +169,7 @@ class HttpPusher(object): finally: self._is_processing = False - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): """ Looks for unset notifications and dispatch them, in order Never call this directly: use _process which will only allow this to @@ -178,7 +177,7 @@ class HttpPusher(object): """ fn = self.store.get_unread_push_actions_for_user_in_range_for_http - unprocessed = yield fn( + unprocessed = await fn( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -200,13 +199,13 @@ class HttpPusher(object): "app_display_name": self.app_display_name, }, ): - processed = yield self._process_one(push_action) + processed = await self._process_one(push_action) if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, self.pushkey, self.user_id, @@ -221,14 +220,14 @@ class HttpPusher(object): if self.failing_since: self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) @@ -247,7 +246,7 @@ class HttpPusher(object): ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, @@ -260,7 +259,7 @@ class HttpPusher(object): return self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: @@ -273,18 +272,17 @@ class HttpPusher(object): ) break - @defer.inlineCallbacks - def _process_one(self, push_action): + async def _process_one(self, push_action): if "notify" not in push_action["actions"]: return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - event = yield self.store.get_event(push_action["event_id"], allow_none=True) + event = await self.store.get_event(push_action["event_id"], allow_none=True) if event is None: return True # It's been redacted - rejected = yield self.dispatch_push(event, tweaks, badge) + rejected = await self.dispatch_push(event, tweaks, badge) if rejected is False: return False @@ -298,17 +296,27 @@ class HttpPusher(object): ) else: logger.info("Pushkey %s was rejected: removing", pk) - yield self.hs.remove_pusher(self.app_id, pk, self.user_id) + await self.hs.remove_pusher(self.app_id, pk, self.user_id) return True - @defer.inlineCallbacks - def _build_notification_dict(self, event, tweaks, badge): + async def _build_notification_dict(self, event, tweaks, badge): + priority = "low" + if ( + event.type == EventTypes.Encrypted + or tweaks.get("highlight") + or tweaks.get("sound") + ): + # HACK send our push as high priority only if it generates a sound, highlight + # or may do so (i.e. is encrypted so has unknown effects). + priority = "high" + if self.data.get("format") == "event_id_only": d = { "notification": { "event_id": event.event_id, "room_id": event.room_id, "counts": {"unread": badge}, + "prio": priority, "devices": [ { "app_id": self.app_id, @@ -321,7 +329,7 @@ class HttpPusher(object): } return d - ctx = yield push_tools.get_context_for_event( + ctx = await push_tools.get_context_for_event( self.storage, self.state_handler, event, self.user_id ) @@ -332,9 +340,8 @@ class HttpPusher(object): "room_id": event.room_id, "type": event.type, "sender": event.user_id, - "counts": { # -- we don't mark messages as read yet so - # we have no way of knowing - # Just set the badge to 1 until we have read receipts + "prio": priority, + "counts": { "unread": badge, # 'missed_calls': 2 }, @@ -364,13 +371,12 @@ class HttpPusher(object): return d - @defer.inlineCallbacks - def dispatch_push(self, event, tweaks, badge): - notification_dict = yield self._build_notification_dict(event, tweaks, badge) + async def dispatch_push(self, event, tweaks, badge): + notification_dict = await self._build_notification_dict(event, tweaks, badge) if not notification_dict: return [] try: - resp = yield self.http_client.post_json_get_json( + resp = await self.http_client.post_json_get_json( self.url, notification_dict ) except Exception as e: @@ -387,8 +393,7 @@ class HttpPusher(object): rejected = resp["rejected"] return rejected - @defer.inlineCallbacks - def _send_badge(self, badge): + async def _send_badge(self, badge): """ Args: badge (int): number of unread messages @@ -411,7 +416,7 @@ class HttpPusher(object): } } try: - yield self.http_client.post_json_get_json(self.url, d) + await self.http_client.post_json_get_json(self.url, d) http_badges_processed_counter.inc() except Exception as e: logger.warning( diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index d57a66a697..6c57854018 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -16,18 +16,17 @@ import email.mime.multipart import email.utils import logging -import time +import urllib.parse from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Iterable, List, TypeVar -from six.moves import urllib - import bleach import jinja2 from synapse.api.constants import EventTypes from synapse.api.errors import StoreError +from synapse.config.emailconfig import EmailSubjectConfig from synapse.logging.context import make_deferred_yieldable from synapse.push.presentable_names import ( calculate_room_name, @@ -43,23 +42,6 @@ logger = logging.getLogger(__name__) T = TypeVar("T") -MESSAGE_FROM_PERSON_IN_ROOM = ( - "You have a message on %(app)s from %(person)s in the %(room)s room..." -) -MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." -MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." -MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..." -MESSAGES_IN_ROOM_AND_OTHERS = ( - "You have messages on %(app)s in the %(room)s room and others..." -) -MESSAGES_FROM_PERSON_AND_OTHERS = ( - "You have messages on %(app)s from %(person)s and others..." -) -INVITE_FROM_PERSON_TO_ROOM = ( - "%(person)s has invited you to join the %(room)s room on %(app)s..." -) -INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." - CONTEXT_BEFORE = 1 CONTEXT_AFTER = 1 @@ -110,7 +92,7 @@ ALLOWED_ATTRS = { # ALLOWED_SCHEMES = ["http", "https", "ftp", "mailto"] -class Mailer(object): +class Mailer: def __init__(self, hs, app_name, template_html, template_text): self.hs = hs self.template_html = template_html @@ -122,6 +104,7 @@ class Mailer(object): self.state_handler = self.hs.get_state_handler() self.storage = hs.get_storage() self.app_name = app_name + self.email_subjects = hs.config.email_subjects # type: EmailSubjectConfig logger.info("Created Mailer for app_name %s" % app_name) @@ -148,7 +131,8 @@ class Mailer(object): await self.send_email( email_address, - "[%s] Password Reset" % self.hs.config.server_name, + self.email_subjects.password_reset + % {"server_name": self.hs.config.server_name}, template_vars, ) @@ -175,7 +159,8 @@ class Mailer(object): await self.send_email( email_address, - "[%s] Register your Email Address" % self.hs.config.server_name, + self.email_subjects.email_validation + % {"server_name": self.hs.config.server_name}, template_vars, ) @@ -203,7 +188,8 @@ class Mailer(object): await self.send_email( email_address, - "[%s] Validate Your Email" % self.hs.config.server_name, + self.email_subjects.email_validation + % {"server_name": self.hs.config.server_name}, template_vars, ) @@ -270,16 +256,13 @@ class Mailer(object): user_id, app_id, email_address ), "summary_text": summary_text, - "app_name": self.app_name, "rooms": rooms, "reason": reason, } - await self.send_email( - email_address, "[%s] %s" % (self.app_name, summary_text), template_vars - ) + await self.send_email(email_address, summary_text, template_vars) - async def send_email(self, email_address, subject, template_vars): + async def send_email(self, email_address, subject, extra_template_vars): """Send an email with the given information and template text""" try: from_string = self.hs.config.email_notif_from % {"app": self.app_name} @@ -292,6 +275,13 @@ class Mailer(object): if raw_to == "": raise RuntimeError("Invalid 'to' address") + template_vars = { + "app_name": self.app_name, + "server_name": self.hs.config.server.server_name, + } + + template_vars.update(extra_template_vars) + html_text = self.template_html.render(**template_vars) html_part = MIMEText(html_text, "html", "utf8") @@ -477,12 +467,12 @@ class Mailer(object): inviter_name = name_from_member_event(inviter_member_event) if room_name is None: - return INVITE_FROM_PERSON % { + return self.email_subjects.invite_from_person % { "person": inviter_name, "app": self.app_name, } else: - return INVITE_FROM_PERSON_TO_ROOM % { + return self.email_subjects.invite_from_person_to_room % { "person": inviter_name, "room": room_name, "app": self.app_name, @@ -500,13 +490,13 @@ class Mailer(object): sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: - return MESSAGE_FROM_PERSON_IN_ROOM % { + return self.email_subjects.message_from_person_in_room % { "person": sender_name, "room": room_name, "app": self.app_name, } elif sender_name is not None: - return MESSAGE_FROM_PERSON % { + return self.email_subjects.message_from_person % { "person": sender_name, "app": self.app_name, } @@ -514,7 +504,10 @@ class Mailer(object): # There's more than one notification for this room, so just # say there are several if room_name is not None: - return MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name} + return self.email_subjects.messages_in_room % { + "room": room_name, + "app": self.app_name, + } else: # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" @@ -532,7 +525,7 @@ class Mailer(object): ] ) - return MESSAGES_FROM_PERSON % { + return self.email_subjects.messages_from_person % { "person": descriptor_from_member_events(member_events.values()), "app": self.app_name, } @@ -541,7 +534,7 @@ class Mailer(object): # ...but we still refer to the 'reason' room which triggered the mail if reason["room_name"] is not None: - return MESSAGES_IN_ROOM_AND_OTHERS % { + return self.email_subjects.messages_in_room_and_others % { "room": reason["room_name"], "app": self.app_name, } @@ -561,7 +554,7 @@ class Mailer(object): [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] ) - return MESSAGES_FROM_PERSON_AND_OTHERS % { + return self.email_subjects.messages_from_person_and_others % { "person": descriptor_from_member_events(member_events.values()), "app": self.app_name, } @@ -646,72 +639,3 @@ def string_ordinal_total(s): for c in s: tot += ord(c) return tot - - -def format_ts_filter(value, format): - return time.strftime(format, time.localtime(value / 1000)) - - -def load_jinja2_templates( - template_dir, - template_filenames, - apply_format_ts_filter=False, - apply_mxc_to_http_filter=False, - public_baseurl=None, -): - """Loads and returns one or more jinja2 templates and applies optional filters - - Args: - template_dir (str): The directory where templates are stored - template_filenames (list[str]): A list of template filenames - apply_format_ts_filter (bool): Whether to apply a template filter that formats - timestamps - apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts - mxc urls to http urls - public_baseurl (str|None): The public baseurl of the server. Required for - apply_mxc_to_http_filter to be enabled - - Returns: - A list of jinja2 templates corresponding to the given list of filenames, - with order preserved - """ - logger.info( - "loading email templates %s from '%s'", template_filenames, template_dir - ) - loader = jinja2.FileSystemLoader(template_dir) - env = jinja2.Environment(loader=loader) - - if apply_format_ts_filter: - env.filters["format_ts"] = format_ts_filter - - if apply_mxc_to_http_filter and public_baseurl: - env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl) - - templates = [] - for template_filename in template_filenames: - template = env.get_template(template_filename) - templates.append(template) - - return templates - - -def _create_mxc_to_http_filter(public_baseurl): - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - serverAndMediaId = value[6:] - fragment = None - if "#" in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - serverAndMediaId, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 0644a13cfc..d8f4a453cd 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -16,8 +16,6 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes logger = logging.getLogger(__name__) @@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$") ALL_ALONE = "Empty Room" -@defer.inlineCallbacks -def calculate_room_name( +async def calculate_room_name( store, room_state_ids, user_id, @@ -53,7 +50,7 @@ def calculate_room_name( """ # does it have a name? if (EventTypes.Name, "") in room_state_ids: - m_room_name = yield store.get_event( + m_room_name = await store.get_event( room_state_ids[(EventTypes.Name, "")], allow_none=True ) if m_room_name and m_room_name.content and m_room_name.content["name"]: @@ -61,7 +58,7 @@ def calculate_room_name( # does it have a canonical alias? if (EventTypes.CanonicalAlias, "") in room_state_ids: - canon_alias = yield store.get_event( + canon_alias = await store.get_event( room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True ) if ( @@ -81,7 +78,7 @@ def calculate_room_name( my_member_event = None if (EventTypes.Member, user_id) in room_state_ids: - my_member_event = yield store.get_event( + my_member_event = await store.get_event( room_state_ids[(EventTypes.Member, user_id)], allow_none=True ) @@ -90,7 +87,7 @@ def calculate_room_name( and my_member_event.content["membership"] == "invite" ): if (EventTypes.Member, my_member_event.sender) in room_state_ids: - inviter_member_event = yield store.get_event( + inviter_member_event = await store.get_event( room_state_ids[(EventTypes.Member, my_member_event.sender)], allow_none=True, ) @@ -107,7 +104,7 @@ def calculate_room_name( # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. if EventTypes.Member in room_state_bytype_ids: - member_events = yield store.get_events( + member_events = await store.get_events( list(room_state_bytype_ids[EventTypes.Member].values()) ) all_members = [ diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 11032491af..709ace01e5 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,9 +16,7 @@ import logging import re -from typing import Pattern - -from six import string_types +from typing import Any, Dict, List, Pattern, Union from synapse.events import EventBase from synapse.types import UserID @@ -74,17 +72,40 @@ def _test_ineq_condition(condition, number): return False -def tweaks_for_actions(actions): +def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: + """ + Converts a list of actions into a `tweaks` dict (which can then be passed to + the push gateway). + + This function ignores all actions other than `set_tweak` actions, and treats + absent `value`s as `True`, which agrees with the only spec-defined treatment + of absent `value`s (namely, for `highlight` tweaks). + + Args: + actions: list of actions + e.g. [ + {"set_tweak": "a", "value": "AAA"}, + {"set_tweak": "b", "value": "BBB"}, + {"set_tweak": "highlight"}, + "notify" + ] + + Returns: + dictionary of tweaks for those actions + e.g. {"a": "AAA", "b": "BBB", "highlight": True} + """ tweaks = {} for a in actions: if not isinstance(a, dict): continue - if "set_tweak" in a and "value" in a: - tweaks[a["set_tweak"]] = a["value"] + if "set_tweak" in a: + # value is allowed to be absent in which case the value assumed + # should be True. + tweaks[a["set_tweak"]] = a.get("value", True) return tweaks -class PushRuleEvaluatorForEvent(object): +class PushRuleEvaluatorForEvent: def __init__( self, event: EventBase, @@ -131,7 +152,7 @@ class PushRuleEvaluatorForEvent(object): # XXX: optimisation: cache our pattern regexps if condition["key"] == "content.body": body = self._event.content.get("body", None) - if not body: + if not body or not isinstance(body, str): return False return _glob_matches(pattern, body, word_boundary=True) @@ -147,7 +168,7 @@ class PushRuleEvaluatorForEvent(object): return False body = self._event.content.get("body", None) - if not body: + if not body or not isinstance(body, str): return False # Similar to _glob_matches, but do not treat display_name as a glob. @@ -244,7 +265,7 @@ def _flatten_dict(d, prefix=[], result=None): if result is None: result = {} for key, value in d.items(): - if isinstance(value, string_types): + if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() elif hasattr(value, "items"): _flatten_dict(value, prefix=(prefix + [key]), result=result) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 5dae4648c0..d0145666bf 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage -@defer.inlineCallbacks -def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_local_user(user_id) - joins = yield store.get_rooms_for_user(user_id) +async def get_badge_count(store, user_id): + invites = await store.get_invited_rooms_for_local_user(user_id) + joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") badge = len(invites) @@ -32,7 +29,7 @@ def get_badge_count(store, user_id): if room_id in my_receipts_by_room: last_unread_event_id = my_receipts_by_room[room_id] - notifs = yield ( + notifs = await ( store.get_unread_event_push_actions_by_room_for_user( room_id, user_id, last_unread_event_id ) @@ -43,23 +40,22 @@ def get_badge_count(store, user_id): return badge -@defer.inlineCallbacks -def get_context_for_event(storage: Storage, state_handler, ev, user_id): +async def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) + room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room - name = yield calculate_room_name( + name = await calculate_room_name( storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield storage.main.get_event(sender_state_event_id) + sender_state_event = await storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 8ad0bf5936..2a52e226e3 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -15,24 +15,15 @@ import logging +from synapse.push.emailpusher import EmailPusher +from synapse.push.mailer import Mailer + from .httppusher import HttpPusher logger = logging.getLogger(__name__) -# We try importing this if we can (it will fail if we don't -# have the optional email dependencies installed). We don't -# yet have the config to know if we need the email pusher, -# but importing this after daemonizing seems to fail -# (even though a simple test of importing from a daemonized -# process works fine) -try: - from synapse.push.emailpusher import EmailPusher - from synapse.push.mailer import Mailer, load_jinja2_templates -except Exception: - pass - -class PusherFactory(object): +class PusherFactory: def __init__(self, hs): self.hs = hs self.config = hs.config @@ -43,16 +34,8 @@ class PusherFactory(object): if hs.config.email_enable_notifs: self.mailers = {} # app_name -> Mailer - self.notif_template_html, self.notif_template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_notif_template_html, - self.config.email_notif_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) + self._notif_template_html = hs.config.email_notif_template_html + self._notif_template_text = hs.config.email_notif_template_text self.pusher_types["email"] = self._create_email_pusher @@ -73,8 +56,8 @@ class PusherFactory(object): mailer = Mailer( hs=self.hs, app_name=app_name, - template_html=self.notif_template_html, - template_text=self.notif_template_text, + template_html=self._notif_template_html, + template_text=self._notif_template_text, ) self.mailers[app_name] = mailer return EmailPusher(self.hs, pusherdict, mailer) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 88d203aa44..3c3262a88c 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -15,13 +15,10 @@ # limitations under the License. import logging -from collections import defaultdict -from threading import Lock -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Dict, Union -from twisted.internet import defer +from prometheus_client import Gauge -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -29,9 +26,18 @@ from synapse.push.httppusher import HttpPusher from synapse.push.pusher import PusherFactory from synapse.util.async_helpers import concurrently_execute +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) +synapse_pushers = Gauge( + "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"] +) + + class PusherPool: """ The pusher pool. This is responsible for dispatching notifications of new events to @@ -44,39 +50,23 @@ class PusherPool: Note that it is expected that each pusher will have its own 'processing' loop which will send out the notifications in the background, rather than blocking until the notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and - Pusher.on_new_receipts are not expected to return deferreds. + Pusher.on_new_receipts are not expected to return awaitables. """ - def __init__(self, _hs): - self.hs = _hs - self.pusher_factory = PusherFactory(_hs) - self._should_start_pushers = _hs.config.start_pushers + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.pusher_factory = PusherFactory(hs) + self._should_start_pushers = hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() + # We shard the handling of push notifications by user ID. + self._pusher_shard_config = hs.config.push.pusher_shard_config + self._instance_name = hs.get_instance_name() + # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] - # a lock for the pushers dict, since `count_pushers` is called from an different - # and we otherwise get concurrent modification errors - self._pushers_lock = Lock() - - def count_pushers(): - results = defaultdict(int) # type: Dict[Tuple[str, str], int] - with self._pushers_lock: - for pushers in self.pushers.values(): - for pusher in pushers.values(): - k = (type(pusher).__name__, pusher.app_id) - results[k] += 1 - return results - - LaterGauge( - name="synapse_pushers", - desc="the number of active pushers", - labels=["kind", "app_id"], - caller=count_pushers, - ) - def start(self): """Starts the pushers off in a background process. """ @@ -85,8 +75,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - @defer.inlineCallbacks - def add_pusher( + async def add_pusher( self, user_id, access_token, @@ -102,8 +91,9 @@ class PusherPool: """Creates a new pusher and adds it to the pool Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ + time_now_msec = self.clock.time_msec() # we try to create the pusher just to validate the config: it @@ -131,9 +121,9 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() + last_stream_ordering = await self.store.get_latest_push_action_stream_ordering() - yield self.store.add_pusher( + await self.store.add_pusher( user_id=user_id, access_token=access_token, kind=kind, @@ -147,15 +137,14 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) - pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) return pusher - @defer.inlineCallbacks - def remove_pushers_by_app_id_and_pushkey_not_user( + async def remove_pushers_by_app_id_and_pushkey_not_user( self, app_id, pushkey, not_user_id ): - to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: if p["user_name"] != not_user_id: logger.info( @@ -164,10 +153,9 @@ class PusherPool: pushkey, p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def remove_pushers_by_access_token(self, user_id, access_tokens): + async def remove_pushers_by_access_token(self, user_id, access_tokens): """Remove the pushers for a given user corresponding to a set of access_tokens. @@ -176,8 +164,11 @@ class PusherPool: access_tokens (Iterable[int]): access token *ids* to remove pushers for """ + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + tokens = set(access_tokens) - for p in (yield self.store.get_pushers_by_user_id(user_id)): + for p in await self.store.get_pushers_by_user_id(user_id): if p["access_token"] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", @@ -185,16 +176,15 @@ class PusherPool: p["pushkey"], p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, min_stream_id, max_stream_id): if not self.pushers: # nothing to do here. return try: - users_affected = yield self.store.get_push_action_users_in_range( + users_affected = await self.store.get_push_action_users_in_range( min_stream_id, max_stream_id ) @@ -206,8 +196,7 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - @defer.inlineCallbacks - def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): if not self.pushers: # nothing to do here. return @@ -215,11 +204,9 @@ class PusherPool: try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive - updated_receipts = yield self.store.get_all_updated_receipts( + users_affected = await self.store.get_users_sent_receipts_between( min_stream_id - 1, max_stream_id ) - # This returns a tuple, user_id is at index 3 - users_affected = {r[3] for r in updated_receipts} for u in users_affected: if u in self.pushers: @@ -229,17 +216,19 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - @defer.inlineCallbacks - def start_pusher_by_id(self, app_id, pushkey, user_id): + async def start_pusher_by_id(self, app_id, pushkey, user_id): """Look up the details for the given pusher, and start it Returns: - Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any + EmailPusher|HttpPusher|None: The pusher started, if any """ if not self._should_start_pushers: return - resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None for r in resultlist: @@ -248,35 +237,35 @@ class PusherPool: pusher = None if pusher_dict: - pusher = yield self._start_pusher(pusher_dict) + pusher = await self._start_pusher(pusher_dict) return pusher - @defer.inlineCallbacks - def _start_pushers(self): + async def _start_pushers(self) -> None: """Start all the pushers - - Returns: - Deferred """ - pushers = yield self.store.get_all_pushers() + pushers = await self.store.get_all_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. - yield concurrently_execute(self._start_pusher, pushers, 10) + await concurrently_execute(self._start_pusher, pushers, 10) logger.info("Started pushers") - @defer.inlineCallbacks - def _start_pusher(self, pusherdict): + async def _start_pusher(self, pusherdict): """Start the given pusher Args: pusherdict (dict): dict with the values pulled from the db table Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ + if not self._pusher_shard_config.should_handle( + self._instance_name, pusherdict["user_name"] + ): + return + try: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: @@ -300,11 +289,12 @@ class PusherPool: appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) - with self._pushers_lock: - byuser = self.pushers.setdefault(pusherdict["user_name"], {}) - if appid_pushkey in byuser: - byuser[appid_pushkey].on_stop() - byuser[appid_pushkey] = p + byuser = self.pushers.setdefault(pusherdict["user_name"], {}) + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p + + synapse_pushers.labels(type(p).__name__, p.app_id).inc() # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to @@ -312,7 +302,7 @@ class PusherPool: user_id = pusherdict["user_name"] last_stream_ordering = pusherdict["last_stream_ordering"] if last_stream_ordering: - have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( + have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering ) else: @@ -324,18 +314,18 @@ class PusherPool: return p - @defer.inlineCallbacks - def remove_pusher(self, app_id, pushkey, user_id): + async def remove_pusher(self, app_id, pushkey, user_id): appid_pushkey = "%s:%s" % (app_id, pushkey) byuser = self.pushers.get(user_id, {}) if appid_pushkey in byuser: logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - byuser[appid_pushkey].on_stop() - with self._pushers_lock: - del byuser[appid_pushkey] + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8b4312e5a3..2d995ec456 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -43,7 +43,7 @@ REQUIREMENTS = [ "jsonschema>=2.5.1", "frozendict>=1", "unpaddedbase64>=1.1.0", - "canonicaljson>=1.1.3", + "canonicaljson>=1.3.0", # we use the type definitions added in signedjson 1.1. "signedjson>=1.1.0", "pynacl>=1.2.1", @@ -59,18 +59,17 @@ REQUIREMENTS = [ "pyyaml>=3.11", "pyasn1>=0.1.9", "pyasn1-modules>=0.0.7", - "daemonize>=2.3.1", "bcrypt>=3.1.0", "pillow>=4.3.0", "sortedcontainers>=1.4.4", "pymacaroons>=0.13.0", "msgpack>=0.5.2", "phonenumbers>=8.2.0", - "six>=1.10", - "prometheus_client>=0.0.18,<0.8.0", - # we use attr.s(slots), which arrived in 16.0.0 - # Twisted 18.7.0 requires attrs>=17.4.0 - "attrs>=17.4.0", + "prometheus_client>=0.0.18,<0.9.0", + # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note: + # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33 + # is out in November.) + "attrs>=19.1.0", "netaddr>=0.7.18", "Jinja2>=2.9", "bleach>=1.4.3", @@ -81,8 +80,6 @@ CONDITIONAL_REQUIREMENTS = { "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], # we use execute_batch, which arrived in psycopg 2.7. "postgres": ["psycopg2>=2.7"], - # ConsentResource uses select_autoescape, which arrived in jinja 2.9 - "resources.consent": ["Jinja2>=2.9"], # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ @@ -95,7 +92,12 @@ CONDITIONAL_REQUIREMENTS = { "oidc": ["authlib>=0.14.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], - "test": ["mock>=2.0", "parameterized"], + # Dependencies which are exclusively required by unit test code. This is + # NOT a list of all modules that are necessary to run the unit tests. + # Tests assume that all optional dependencies are installed. + # + # parameterized_class decorator was introduced in parameterized 0.7.0 + "test": ["mock>=2.0", "parameterized>=0.7.0"], "sentry": ["sentry-sdk>=0.7.2"], "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 19b69e0e11..a84a064c8d 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication" class ReplicationRestResource(JsonResource): def __init__(self, hs): - JsonResource.__init__(self, hs, canonical_json=False) + # We enable extracting jaeger contexts here as these are internal APIs. + super().__init__(hs, canonical_json=False, extract_context=True) self.register_servlets(hs) def register_servlets(self, hs): @@ -38,10 +39,10 @@ class ReplicationRestResource(JsonResource): federation.register_servlets(hs, self) presence.register_servlets(hs, self) membership.register_servlets(hs, self) + streams.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) - streams.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 793cef6c26..ba16f22c91 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,32 +16,24 @@ import abc import logging import re +import urllib from inspect import signature from typing import Dict, List, Tuple -from six import raise_from -from six.moves import urllib - -from twisted.internet import defer - from synapse.api.errors import ( CodeMessageException, HttpResponseException, RequestSendFailed, SynapseError, ) -from synapse.logging.opentracing import ( - inject_active_span_byte_dict, - trace, - trace_servlet, -) +from synapse.logging.opentracing import inject_active_span_byte_dict, trace from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) -class ReplicationEndpoint(object): +class ReplicationEndpoint: """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` @@ -98,16 +90,16 @@ class ReplicationEndpoint(object): # assert here that sub classes don't try and use the name. assert ( "instance_name" not in self.PATH_ARGS - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert ( "instance_name" not in signature(self.__class__._serialize_payload).parameters - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod - def _serialize_payload(**kwargs): + async def _serialize_payload(**kwargs): """Static method that is called when creating a request. Concrete implementations should have explicit parameters (rather than @@ -116,9 +108,8 @@ class ReplicationEndpoint(object): argument list. Returns: - Deferred[dict]|dict: If POST/PUT request then dictionary must be - JSON serialisable, otherwise must be appropriate for adding as - query args. + dict: If POST/PUT request then dictionary must be JSON serialisable, + otherwise must be appropriate for adding as query args. """ return {} @@ -150,8 +141,7 @@ class ReplicationEndpoint(object): instance_map = hs.config.worker.instance_map @trace(opname="outgoing_replication_request") - @defer.inlineCallbacks - def send_request(instance_name="master", **kwargs): + async def send_request(instance_name="master", **kwargs): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") if instance_name == "master": @@ -165,7 +155,7 @@ class ReplicationEndpoint(object): "Instance %r not in 'instance_map' config" % (instance_name,) ) - data = yield cls._serialize_payload(**kwargs) + data = await cls._serialize_payload(**kwargs) url_args = [ urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS @@ -203,7 +193,7 @@ class ReplicationEndpoint(object): headers = {} # type: Dict[bytes, List[bytes]] inject_active_span_byte_dict(headers, None, check_destination=False) try: - result = yield request_func(uri, data, headers=headers) + result = await request_func(uri, data, headers=headers) break except CodeMessageException as e: if e.code != 504 or not cls.RETRY_ON_TIMEOUT: @@ -213,14 +203,14 @@ class ReplicationEndpoint(object): # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. - yield clock.sleep(1) + await clock.sleep(1) except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And # importantly, not stack traces everywhere) raise e.to_synapse_error() except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to talk to master"), e) + raise SynapseError(502, "Failed to talk to master") from e return result @@ -242,11 +232,8 @@ class ReplicationEndpoint(object): args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) - handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler) - # We don't let register paths trace this servlet using the default tracing - # options because we wish to extract the context explicitly. http_server.register_paths( - method, [pattern], handler, self.__class__.__name__, trace=False + method, [pattern], handler, self.__class__.__name__, ) def _cached_handler(self, request, txn_id, **kwargs): diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index e32aac0a25..20f3ba76c0 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(user_id): + async def _serialize_payload(user_id): return {} async def _handle_request(self, request, user_id): diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index c287c4e269..6b56315148 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext @@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - @defer.inlineCallbacks - def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, event_and_contexts, backfilled): """ Args: store @@ -78,7 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """ event_payloads = [] for event, context in event_and_contexts: - serialized_context = yield context.serialize(event, store) + serialized_context = await context.serialize(event, store) event_payloads.append( { @@ -154,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): self.registry = hs.get_federation_registry() @staticmethod - def _serialize_payload(edu_type, origin, content): + async def _serialize_payload(edu_type, origin, content): return {"origin": origin, "content": content} async def _handle_request(self, request, edu_type): @@ -197,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): self.registry = hs.get_federation_registry() @staticmethod - def _serialize_payload(query_type, args): + async def _serialize_payload(query_type, args): """ Args: query_type (str) @@ -238,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() @staticmethod - def _serialize_payload(room_id, args): + async def _serialize_payload(room_id, args): """ Args: room_id (str) @@ -273,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() @staticmethod - def _serialize_payload(room_id, room_version): + async def _serialize_payload(room_id, room_version): return {"room_version": room_version.identifier} async def _handle_request(self, request, room_id): diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 798b9d3af5..fb326bb869 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload(user_id, device_id, initial_display_name, is_guest): + async def _serialize_payload(user_id, device_id, initial_display_name, is_guest): """ Args: device_id (str|None): Device ID to use, if None a new one is diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index a7174c4a8f..741329ab5f 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -14,11 +14,11 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.util.distributor import user_joined_room, user_left_room if TYPE_CHECKING: @@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): + async def _serialize_payload( + requester, room_id, user_id, remote_room_hosts, content + ): """ Args: requester(Requester) @@ -88,49 +90,54 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): - """Rejects the invite for the user and room. + """Rejects an out-of-band invite we have received from a remote server Request format: - POST /_synapse/replication/remote_reject_invite/:room_id/:user_id + POST /_synapse/replication/remote_reject_invite/:event_id { + "txn_id": ..., "requester": ..., - "remote_room_hosts": [...], "content": { ... } } """ NAME = "remote_reject_invite" - PATH_ARGS = ("room_id", "user_id") + PATH_ARGS = ("invite_event_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) - self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): + async def _serialize_payload( # type: ignore + invite_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ): """ Args: - requester(Requester) - room_id (str) - user_id (str) - remote_room_hosts (list[str]): Servers to try and reject via + invite_event_id: ID of the invite to be rejected + txn_id: optional transaction ID supplied by the client + requester: user making the rejection request, according to the access token + content: additional content to include in the rejection event. + Normally an empty dict. """ return { + "txn_id": txn_id, "requester": requester.serialize(), - "remote_room_hosts": remote_room_hosts, "content": content, } - async def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, invite_event_id): content = parse_json_object_from_request(request) - remote_room_hosts = content["remote_room_hosts"] + txn_id = content["txn_id"] event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) @@ -138,60 +145,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) - - try: - event, stream_id = await self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, user_id, event_content, - ) - event_id = event.event_id - except Exception as e: - # if we were unable to reject the exception, just mark - # it as rejected on our end and plough ahead. - # - # The 'except' clause is very broad, but we need to - # capture everything from DNS failures upwards - # - logger.warning("Failed to reject invite: %s", e) - - stream_id = await self.member_handler.locally_reject_invite( - user_id, room_id - ) - event_id = None + # hopefully we're now on the master, so this won't recurse! + event_id, stream_id = await self.member_handler.remote_reject_invite( + invite_event_id, txn_id, requester, event_content, + ) return 200, {"event_id": event_id, "stream_id": stream_id} -class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint): - """Rejects the invite for the user and room locally. - - Request format: - - POST /_synapse/replication/locally_reject_invite/:room_id/:user_id - - {} - """ - - NAME = "locally_reject_invite" - PATH_ARGS = ("room_id", "user_id") - - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - self.member_handler = hs.get_room_member_handler() - - @staticmethod - def _serialize_payload(room_id, user_id): - return {} - - async def _handle_request(self, request, room_id, user_id): - logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id) - - stream_id = await self.member_handler.locally_reject_invite(user_id, room_id) - - return 200, {"stream_id": stream_id} - - class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): """Notifies that a user has joined or left the room @@ -215,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): self.distributor = hs.get_distributor() @staticmethod - def _serialize_payload(room_id, user_id, change): + async def _serialize_payload(room_id, user_id, change): """ Args: room_id (str) @@ -245,4 +206,3 @@ def register_servlets(hs, http_server): ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) - ReplicationLocallyRejectInviteRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index ea1b33331b..bc9aa82cb4 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - def _serialize_payload(user_id): + async def _serialize_payload(user_id): return {} async def _handle_request(self, request, user_id): @@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - def _serialize_payload(user_id, state, ignore_status_msg=False): + async def _serialize_payload(user_id, state, ignore_status_msg=False): return { "state": state, "ignore_status_msg": ignore_status_msg, diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 0c4aca1291..a02b27474d 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload( + async def _serialize_payload( user_id, password_hash, was_guest, @@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin, user_type, address, + shadow_banned, ): """ Args: @@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the regitration. + shadow_banned (bool): Whether to shadow-ban the user """ return { "password_hash": password_hash, @@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "admin": admin, "user_type": user_type, "address": address, + "shadow_banned": shadow_banned, } async def _handle_request(self, request, user_id): @@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin=content["admin"], user_type=content["user_type"], address=content["address"], + shadow_banned=content["shadow_banned"], ) return 200, {} @@ -105,7 +109,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload(user_id, auth_result, access_token): + async def _serialize_payload(user_id, auth_result, access_token): """ Args: user_id (str): The user ID that consented diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index c981723c1a..f13d452426 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext @@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - @defer.inlineCallbacks - def _serialize_payload( + async def _serialize_payload( event_id, store, event, context, requester, ratelimit, extra_users ): """ @@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): extra_users (list(UserID)): Any extra users to notify about event """ - serialized_context = yield context.serialize(event, store) + serialized_context = await context.serialize(event, store) payload = { "event": event.get_pdu_json(), diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index bde97eef32..309159e304 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): self.streams = hs.get_replication_streams() @staticmethod - def _serialize_payload(stream_name, from_token, upto_token): + async def _serialize_payload(stream_name, from_token, upto_token): return {"from_token": from_token, "upto_token": upto_token} async def _handle_request(self, request, stream_name): diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f9e2533e96..60f2e1245f 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -16,8 +16,8 @@ import logging from typing import Optional -from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) class BaseSlavedStore(CacheInvalidationWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = MultiWriterIdGenerator( diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 9d1d173b2f..eb74903d68 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -16,14 +16,14 @@ from synapse.storage.util.id_generators import _load_current_id -class SlavedIdTracker(object): +class SlavedIdTracker: def __init__(self, db_conn, table, column, extra_tables=[], step=1): self.step = step self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self.advance(_load_current_id(db_conn, table, column)) + self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, new_id): + def advance(self, instance_name, new_id): self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self): @@ -33,3 +33,11 @@ class SlavedIdTracker(object): int """ return self._current + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 9db6c62bc7..bb66ba9b80 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -16,13 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore -from synapse.storage.data_stores.main.tags import TagsWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.databases.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data", @@ -39,13 +40,13 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved return self._account_data_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "tag_account_data": - self._account_data_id_gen.advance(token) + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - elif stream_name == "account_data": - self._account_data_id_gen.advance(token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index a67fbeffb7..0f8d7037bd 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.appservice import ( +from synapse.storage.databases.main.appservice import ( ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 1a38f53dfb..a6fdedde63 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -13,22 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 ) - def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): + async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user_id, access_token, ip) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 6e7fd259d4..533d927701 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -15,17 +15,18 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import ToDeviceStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_max_stream_id", "stream_id" + db_conn, "device_inbox", "stream_id" ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", @@ -44,8 +45,8 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): ) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "to_device": - self._device_inbox_id_gen.advance(token) + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 9d8067342f..3b788c9625 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -16,14 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream -from synapse.storage.data_stores.main.devices import DeviceWorkerStore -from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.devices import DeviceWorkerStore +from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs @@ -48,12 +48,15 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 8b9717c46f..1945bcf9a8 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.directory import DirectoryWorkerStore +from synapse.storage.databases.main.directory import DirectoryWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 1a1a50a24f..da1cc836cf 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -15,18 +15,18 @@ # limitations under the License. import logging -from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore -from synapse.storage.data_stores.main.event_push_actions import ( +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore +from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.relations import RelationsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.data_stores.main.signatures import SignatureWorkerStore -from synapse.storage.data_stores.main.state import StateGroupWorkerStore -from synapse.storage.data_stores.main.stream import StreamWorkerStore -from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore -from synapse.storage.database import Database +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.relations import RelationsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore @@ -55,11 +55,11 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedEventStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index bcb0688954..2562b6fc38 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.filtering import FilteringStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.filtering import FilteringStore from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 1851e7d525..567b4a5cc1 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -15,13 +15,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import GroupServerStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs @@ -38,8 +39,8 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): return self._group_updates_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "groups": - self._group_updates_id_gen.advance(token) + if stream_name == GroupServerStream.NAME: + self._group_updates_id_gen.advance(instance_name, token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index 3def367ae9..961579751c 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.keys import KeyStore +from synapse.storage.databases.main.keys import KeyStore # KeyStore isn't really safe to use from a worker, but for now we do so and hope that # the races it creates aren't too bad. diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 4e0124842d..025f6f6be8 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.tcp.streams import PresenceStream from synapse.storage import DataStore -from synapse.storage.data_stores.main.presence import PresenceStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.presence import PresenceStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore @@ -23,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") @@ -42,8 +43,8 @@ class SlavedPresenceStore(BaseSlavedStore): return self._presence_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "presence": - self._presence_id_gen.advance(token) + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py index 28c508aad3..f85b20a071 100644 --- a/synapse/replication/slave/storage/profile.py +++ b/synapse/replication/slave/storage/profile.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.storage.data_stores.main.profile import ProfileWorkerStore +from synapse.storage.databases.main.profile import ProfileWorkerStore class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 6adb19463a..de904c943c 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -14,24 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import PushRulesStream +from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def get_push_rules_stream_token(self): - return ( - self._push_rules_stream_id_gen.get_current_token(), - self._stream_id_gen.get_current_token(), - ) - def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "push_rules": - self._push_rules_stream_id_gen.advance(token) + # We assert this for the benefit of mypy + assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) + + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index cb78b49acb..9da218bfe8 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -14,15 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.pusher import PusherWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import PushersStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.pusher import PusherWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] @@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): return self._pushers_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "pushers": - self._pushers_id_gen.advance(token) + if stream_name == PushersStream.NAME: + self._pushers_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index be716cc558..5c2986e050 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -14,23 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import ReceiptsStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -# So, um, we want to borrow a load of functions intended for reading from -# a DataStore, but we don't want to take functions that either write to the -# DataStore or are cached and don't have cache invalidation logic. -# -# Rather than write duplicate versions of those functions, or lift them to -# a common base class, we going to grab the underlying __func__ object from -# the method descriptor on the DataStore and chuck them into our class. - class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( @@ -52,8 +45,8 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): self.get_receipts_for_room.invalidate((room_id, receipt_type)) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "receipts": - self._receipts_id_gen.advance(token) + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) for row in rows: self.invalidate_caches_for_receipt( row.room_id, row.receipt_type, row.user_id diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index 4b8553e250..a40f064e2b 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.registration import RegistrationWorkerStore +from synapse.storage.databases.main.registration import RegistrationWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 8873bf37e5..80ae803ad9 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.room import RoomWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import PublicRoomsStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.room import RoomWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" @@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): return self._public_room_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "public_rooms": - self._public_room_id_gen.advance(token) + if stream_name == PublicRoomsStream.NAME: + self._public_room_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index ac88e6b8c3..2091ac0df6 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.transactions import TransactionStore +from synapse.storage.databases.main.transactions import TransactionStore from ._base import BaseSlavedStore diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py index 523a1358d4..1b8718b11d 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py @@ -25,7 +25,7 @@ Structure of the module: * command.py - the definitions of all the valid commands * protocol.py - the TCP protocol classes * resource.py - handles streaming stream updates to replications - * streams/ - the definitons of all the valid streams + * streams/ - the definitions of all the valid streams The general interaction of the classes are: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index df29732f51..d6ecf5b327 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,7 +14,6 @@ # limitations under the License. """A replication client for use by synapse workers. """ -import heapq import logging from typing import TYPE_CHECKING, Dict, List, Tuple @@ -24,6 +23,7 @@ from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -104,6 +104,7 @@ class ReplicationDataHandler: self._clock = hs.get_clock() self._streams = hs.get_replication_streams() self._instance_name = hs.get_instance_name() + self._typing_handler = hs.get_typing_handler() # Map from stream to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. @@ -127,6 +128,12 @@ class ReplicationDataHandler: """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + if stream_name == TypingStream.NAME: + self._typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows] + ) + if stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -211,9 +218,8 @@ class ReplicationDataHandler: waiting_list = self._streams_to_waiters.setdefault(stream_name, []) - # We insert into the list using heapq as it is more efficient than - # pushing then resorting each time. - heapq.heappush(waiting_list, (position, deferred)) + waiting_list.append((position, deferred)) + waiting_list.sort(key=lambda t: t[0]) # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index c04f622816..8cd47770c1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -19,17 +19,9 @@ allowed to be sent by which side. """ import abc import logging -import platform from typing import Tuple, Type -if platform.python_implementation() == "PyPy": - import json - - _json_encoder = json.JSONEncoder() -else: - import simplejson as json # type: ignore[no-redef] # noqa: F821 - - _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821 +from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) @@ -54,7 +46,7 @@ class Command(metaclass=abc.ABCMeta): @abc.abstractmethod def to_line(self) -> str: - """Serialises the comamnd for the wire. Does not include the command + """Serialises the command for the wire. Does not include the command prefix. """ @@ -131,7 +123,7 @@ class RdataCommand(Command): stream_name, instance_name, None if token == "batch" else int(token), - json.loads(row_json), + json_decoder.decode(row_json), ) def to_line(self): @@ -140,7 +132,7 @@ class RdataCommand(Command): self.stream_name, self.instance_name, str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), + json_encoder.encode(self.row), ) ) @@ -149,7 +141,7 @@ class RdataCommand(Command): class PositionCommand(Command): - """Sent by the server to tell the client the stream postition without + """Sent by the server to tell the client the stream position without needing to send an RDATA. Format:: @@ -188,7 +180,7 @@ class ErrorCommand(_SimpleCommand): class PingCommand(_SimpleCommand): - """Sent by either side as a keep alive. The data is arbitary (often timestamp) + """Sent by either side as a keep alive. The data is arbitrary (often timestamp) """ NAME = "PING" @@ -300,20 +292,22 @@ class FederationAckCommand(Command): Format:: - FEDERATION_ACK <token> + FEDERATION_ACK <instance_name> <token> """ NAME = "FEDERATION_ACK" - def __init__(self, token): + def __init__(self, instance_name, token): + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - return cls(int(line)) + instance_name, token = line.split(" ") + return cls(instance_name, int(token)) def to_line(self): - return str(self.token) + return "%s %s" % (self.instance_name, self.token) class RemovePusherCommand(Command): @@ -363,7 +357,7 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) @@ -371,7 +365,7 @@ class UserIpCommand(Command): return ( self.user_id + " " - + _json_encoder.encode( + + json_encoder.encode( ( self.access_token, self.ip, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index cbcf46f3ae..1c303f3a46 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -13,15 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar +from typing import ( + Any, + Awaitable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from prometheus_client import Counter +from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, @@ -43,8 +56,8 @@ from synapse.replication.tcp.streams import ( EventsStream, FederationStream, Stream, + TypingStream, ) -from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) @@ -56,12 +69,16 @@ inbound_rdata_count = Counter( user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) + user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") +# the type of the entries in _command_queues_by_stream +_StreamCommandQueue = Deque[ + Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] +] + + class ReplicationCommandHandler: """Handles incoming commands from replication as well as sending commands back out to connections. @@ -97,6 +114,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, TypingStream): + # Only add TypingStream as a source on the instance in charge of + # typing. + if hs.config.worker.writers.typing == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue @@ -108,12 +133,8 @@ class ReplicationCommandHandler: self._streams_to_replicate.append(stream) - self._position_linearizer = Linearizer( - "replication_position", clock=self._clock - ) - - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. + # Map of stream name to batched updates. See RdataCommand for info on + # how batching works. self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. @@ -123,9 +144,6 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] - # For each connection, the incoming streams that are coming from that connection - self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] - LaterGauge( "synapse_replication_tcp_resource_total_connections", "", @@ -133,6 +151,32 @@ class ReplicationCommandHandler: lambda: len(self._connections), ) + # When POSITION or RDATA commands arrive, we stick them in a queue and process + # them in order in a separate background process. + + # the streams which are currently being processed by _unsafe_process_queue + self._processing_streams = set() # type: Set[str] + + # for each stream, a queue of commands that are awaiting processing, and the + # connection that they arrived on. + self._command_queues_by_stream = { + stream_name: _StreamCommandQueue() for stream_name in self._streams + } + + # For each connection, the incoming stream names that have received a POSITION + # from that connection. + self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + + LaterGauge( + "synapse_replication_tcp_command_queue", + "Number of inbound RDATA/POSITION commands queued for processing", + ["stream_name"], + lambda: { + (stream_name,): len(queue) + for stream_name, queue in self._command_queues_by_stream.items() + }, + ) + self._is_master = hs.config.worker_app is None self._federation_sender = None @@ -143,15 +187,75 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + def _add_command_to_stream_queue( + self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] + ) -> None: + """Queue the given received command for processing + + Adds the given command to the per-stream queue, and processes the queue if + necessary + """ + stream_name = cmd.stream_name + queue = self._command_queues_by_stream.get(stream_name) + if queue is None: + logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name) + return + + queue.append((cmd, conn)) + + # if we're already processing this stream, there's nothing more to do: + # the new entry on the queue will get picked up in due course + if stream_name in self._processing_streams: + return + + # fire off a background process to start processing the queue. + run_as_background_process( + "process-replication-data", self._unsafe_process_queue, stream_name + ) + + async def _unsafe_process_queue(self, stream_name: str): + """Processes the command queue for the given stream, until it is empty + + Does not check if there is already a thread processing the queue, hence "unsafe" + """ + assert stream_name not in self._processing_streams + + self._processing_streams.add(stream_name) + try: + queue = self._command_queues_by_stream.get(stream_name) + while queue: + cmd, conn = queue.popleft() + try: + await self._process_command(cmd, conn, stream_name) + except Exception: + logger.exception("Failed to handle command %s", cmd) + finally: + self._processing_streams.discard(stream_name) + + async def _process_command( + self, + cmd: Union[PositionCommand, RdataCommand], + conn: AbstractConnection, + stream_name: str, + ) -> None: + if isinstance(cmd, PositionCommand): + await self._process_position(stream_name, conn, cmd) + elif isinstance(cmd, RdataCommand): + await self._process_rdata(stream_name, conn, cmd) + else: + # This shouldn't be possible + raise Exception("Unrecognised command %s in stream queue", cmd.NAME) + def start_replication(self, hs): """Helper method to start a replication connection to the remote server using TCP. """ if hs.config.redis.redis_enabled: + import txredisapi + from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, ) - import txredisapi logger.info( "Connecting to redis (host=%r port=%r)", @@ -198,7 +302,7 @@ class ReplicationCommandHandler: """ return self._streams_to_replicate - async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): self.send_positions_to_connection(conn) def send_positions_to_connection(self, conn: AbstractConnection): @@ -217,57 +321,73 @@ class ReplicationCommandHandler: ) ) - async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): + def on_USER_SYNC( + self, conn: AbstractConnection, cmd: UserSyncCommand + ) -> Optional[Awaitable[None]]: user_sync_counter.inc() if self._is_master: - await self._presence_handler.update_external_syncs_row( + return self._presence_handler.update_external_syncs_row( cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) + else: + return None - async def on_CLEAR_USER_SYNC( + def on_CLEAR_USER_SYNC( self, conn: AbstractConnection, cmd: ClearUserSyncsCommand - ): + ) -> Optional[Awaitable[None]]: if self._is_master: - await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + return self._presence_handler.update_external_syncs_clear(cmd.instance_id) + else: + return None - async def on_FEDERATION_ACK( - self, conn: AbstractConnection, cmd: FederationAckCommand - ): + def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): federation_ack_counter.inc() if self._federation_sender: - self._federation_sender.federation_ack(cmd.token) + self._federation_sender.federation_ack(cmd.instance_name, cmd.token) - async def on_REMOVE_PUSHER( + def on_REMOVE_PUSHER( self, conn: AbstractConnection, cmd: RemovePusherCommand - ): + ) -> Optional[Awaitable[None]]: remove_pusher_counter.inc() if self._is_master: - await self._store.delete_pusher_by_app_id_pushkey_user_id( - app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id - ) + return self._handle_remove_pusher(cmd) + else: + return None + + async def _handle_remove_pusher(self, cmd: RemovePusherCommand): + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) - self._notifier.on_new_replication_data() + self._notifier.on_new_replication_data() - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): + def on_USER_IP( + self, conn: AbstractConnection, cmd: UserIpCommand + ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() if self._is_master: - await self._store.insert_client_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) + return self._handle_user_ip(cmd) + else: + return None + + async def _handle_user_ip(self, cmd: UserIpCommand): + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) - if self._server_notices_sender: - await self._server_notices_sender.on_user_ip(cmd.user_id) + assert self._server_notices_sender is not None + await self._server_notices_sender.on_user_ip(cmd.user_id) - async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -275,42 +395,71 @@ class ReplicationCommandHandler: stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) - raise - - # We linearize here for two reasons: + # We put the received command into a queue here for two reasons: # 1. so we don't try and concurrently handle multiple rows for the # same stream, and # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. - with await self._position_linearizer.queue(cmd.stream_name): - # make sure that we've processed a POSITION for this stream *on this - # connection*. (A POSITION on another connection is no good, as there - # is no guarantee that we have seen all the intermediate updates.) - sbc = self._streams_by_connection.get(conn) - if not sbc or stream_name not in sbc: - # Let's drop the row for now, on the assumption we'll receive a - # `POSITION` soon and we'll catch up correctly then. - logger.debug( - "Discarding RDATA for unconnected stream %s -> %s", - stream_name, - cmd.token, - ) - return - - if cmd.token is None: - # I.e. this is part of a batch of updates for this stream (in - # which case batch until we get an update for the stream with a non - # None token). - self._pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self._pending_batches.pop(stream_name, []) - rows.append(row) - await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + self._add_command_to_stream_queue(conn, cmd) + + async def _process_rdata( + self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + ) -> None: + """Process an RDATA command + + Called after the command has been popped off the queue of inbound commands + """ + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception as e: + raise Exception( + "Failed to parse RDATA: %r %r" % (stream_name, cmd.row) + ) from e + + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.debug( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). + self._pending_batches.setdefault(stream_name, []).append(row) + return + + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + + stream = self._streams[stream_name] + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # Discard this data if this token is earlier than the current + # position. Note that streams can be reset (in which case you + # expect an earlier token), but that must be preceded by a + # POSITION command. + if cmd.token <= current_token: + logger.debug( + "Discarding RDATA from stream %s at position %s before previous position %s", + stream_name, + cmd.token, + current_token, + ) + else: + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -329,78 +478,74 @@ class ReplicationCommandHandler: stream_name, instance_name, token, rows ) - async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) - stream_name = cmd.stream_name - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", stream_name) - return + self._add_command_to_stream_queue(conn, cmd) - # We protect catching up with a linearizer in case the replication - # connection reconnects under us. - with await self._position_linearizer.queue(stream_name): - # We're about to go and catch up with the stream, so remove from set - # of connected streams. - for streams in self._streams_by_connection.values(): - streams.discard(stream_name) - - # We clear the pending batches for the stream as the fetching of the - # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(stream_name, []) - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # If the position token matches our current token then we're up to - # date and there's nothing to do. Otherwise, fetch all updates - # between then and now. - missing_updates = cmd.token != current_token - while missing_updates: - logger.info( - "Fetching replication rows for '%s' between %i and %i", - stream_name, - current_token, - cmd.token, - ) - ( - updates, - current_token, - missing_updates, - ) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token - ) + async def _process_position( + self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + ) -> None: + """Process a POSITION command - # TODO: add some tests for this + Called after the command has been popped off the queue of inbound commands + """ + stream = self._streams[stream_name] - # Some streams return multiple rows with the same stream IDs, - # which need to be processed in batches. + # We're about to go and catch up with the stream, so remove from set + # of connected streams. + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) - for token, rows in _batch_updates(updates): - await self.on_rdata( - stream_name, - cmd.instance_name, - token, - [stream.parse_row(row) for row in rows], - ) + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(stream_name, []) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) - # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. + missing_updates = cmd.token != current_token + while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, + ) + (updates, current_token, missing_updates) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token ) - self._streams_by_connection.setdefault(conn, set()).add(stream_name) + # TODO: add some tests for this - async def on_REMOTE_SERVER_UP( - self, conn: AbstractConnection, cmd: RemoteServerUpCommand - ): + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): + await self.on_rdata( + stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], + ) + + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) + + self._streams_by_connection.setdefault(conn, set()).add(stream_name) + + def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -505,7 +650,7 @@ class ReplicationCommandHandler: """Ack data for the federation stream. This allows the master to drop data stored purely in memory. """ - self.send_command(FederationAckCommand(token)) + self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 4198eece71..0b0d204e64 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -50,6 +50,7 @@ import abc import fcntl import logging import struct +from inspect import isawaitable from typing import TYPE_CHECKING, List from prometheus_client import Counter @@ -57,8 +58,12 @@ from prometheus_client import Counter from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure +from synapse.logging.context import PreserveLoggingContext from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, @@ -108,7 +113,7 @@ PING_TIMEOUT_MULTIPLIER = 5 PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER -class ConnectionStates(object): +class ConnectionStates: CONNECTING = "connecting" ESTABLISHED = "established" PAUSED = "paused" @@ -124,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`. + `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine; + if so, that will get run as a background process. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -160,6 +167,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + ctx_name = "replication-conn-%s" % self.conn_id + self._logging_context = BackgroundProcessLoggingContext(ctx_name) + self._logging_context.request = ctx_name + def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -210,6 +223,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def lineReceived(self, line: bytes): """Called when we've received a line """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_line(line) + + def _parse_and_dispatch_line(self, line: bytes): if line.strip() == "": # Ignore blank lines return @@ -232,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() - # Now lets try and call on_<CMD_NAME> function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. First calls `self.on_<COMMAND>` if it exists, then calls - `self.command_handler.on_<COMMAND>` if it exists. This allows for - protocol level handling of commands (e.g. PINGs), before delegating to - the handler. + `self.command_handler.on_<COMMAND>` if it exists (which can optionally + return an Awaitable). + + This allows for protocol level handling of commands (e.g. PINGs), before + delegating to the handler. Args: cmd: received command @@ -254,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # specific handling. cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + cmd_func(cmd) handled = True # Then call out to the handler. cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(self, cmd) + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) + handled = True if not handled: @@ -317,7 +342,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def _queue_command(self, cmd): """Queue the command until the connection is ready to write to again. """ - logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) if len(self.pending_commands) > self.max_line_buffer: @@ -336,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - async def on_PING(self, line): + def on_PING(self, line): self.received_ping = True - async def on_ERROR(self, cmd): + def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -397,6 +422,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: self.transport.unregisterProducer() + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def __str__(self): addr = None if self.transport: @@ -431,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ServerCommand(self.server_name)) super().connectionMade() - async def on_NAME(self, cmd): + def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data @@ -460,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Once we've connected subscribe to the necessary streams self.replicate() - async def on_SERVER(self, cmd): + def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index e776b63183..f225e533de 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -14,12 +14,16 @@ # limitations under the License. import logging +from inspect import isawaitable from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( Command, ReplicateCommand, @@ -66,6 +70,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): stream_name = None # type: str outbound_redis_connection = None # type: txredisapi.RedisProtocol + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + self._logging_context = BackgroundProcessLoggingContext( + "replication_command_handler" + ) + def connectionMade(self): logger.info("Connected to redis") super().connectionMade() @@ -92,7 +105,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_message(message) + def _parse_and_dispatch_message(self, message: str): if message.strip() == "": # Ignore blank lines return @@ -109,42 +125,41 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # remote instances. tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc() - # Now lets try and call on_<CMD_NAME> function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. - By default delegates to on_<COMMAND>, which should return an awaitable. + Delegates to `self.handler.on_<COMMAND>` (which can optionally return an + Awaitable). Args: cmd: received command """ - handled = False - - # First call any command handlers on this instance. These are for redis - # specific handling. - cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - # Then call out to the handler. cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(self, cmd) - handled = True - - if not handled: + if not cmd_func: logger.warning("Unhandled command: %r", cmd) + return + + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) def connectionLost(self, reason): logger.info("Lost connection to redis") super().connectionLost(reason) self.handler.lost_connection(self) + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def send_command(self, cmd: Command): """Send a command if connection has been established. @@ -177,7 +192,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): Args: hs outbound_redis_connection: A connection to redis that will be used to - send outbound commands (this is seperate to the redis connection + send outbound commands (this is separate to the redis connection used to subscribe). """ diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 41569305df..04d894fb3d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -58,7 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): ) -class ReplicationStreamer(object): +class ReplicationStreamer: """Handles replication connections. This needs to be poked when new replication data may be available. When new diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4acefc8a96..682d47f402 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -79,7 +79,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] -class Stream(object): +class Stream: """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last @@ -198,26 +198,6 @@ def current_token_without_instance( return lambda instance_name: current_token() -def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] -) -> UpdateFunction: - """Wraps a db query function which returns a list of rows to make it - suitable for use as an `update_function` for the Stream class - """ - - async def update_function(instance_name, from_token, upto_token, limit): - rows = await query_function(from_token, upto_token, limit) - updates = [(row[0], row[1:]) for row in rows] - limited = False - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return update_function - - def make_http_update_function(hs, stream_name: str) -> UpdateFunction: """Makes a suitable function for use as an `update_function` that queries the master process for updates. @@ -264,7 +244,7 @@ class BackfillStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_current_backfill_token), - db_query_to_update_function(store.get_all_new_backfill_event_rows), + store.get_all_new_backfill_event_rows, ) @@ -291,9 +271,7 @@ class PresenceStream(Stream): if hs.config.worker_app is None: # on the master, query the presence handler presence_handler = hs.get_presence_handler() - update_function = db_query_to_update_function( - presence_handler.get_all_presence_updates - ) + update_function = presence_handler.get_all_presence_updates else: # Query master process update_function = make_http_update_function(hs, self.NAME) @@ -316,13 +294,12 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - if hs.config.worker_app is None: - # on the master, query the typing handler - update_function = db_query_to_update_function( - typing_handler.get_all_typing_updates - ) + writer_instance = hs.config.worker.writers.typing + if writer_instance == hs.get_instance_name(): + # On the writer, query the typing handler + update_function = typing_handler.get_all_typing_updates else: - # Query master process + # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) super().__init__( @@ -352,7 +329,7 @@ class ReceiptsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), - db_query_to_update_function(store.get_all_updated_receipts), + store.get_all_updated_receipts, ) @@ -367,26 +344,17 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() + super(PushRulesStream, self).__init__( - hs.get_instance_name(), self._current_token, self._update_function + hs.get_instance_name(), + self._current_token, + self.store.get_all_push_rule_updates, ) def _current_token(self, instance_name: str) -> int: - push_rules_token, _ = self.store.get_push_rules_stream_token() + push_rules_token = self.store.get_max_push_rules_stream_id() return push_rules_token - async def _update_function( - self, instance_name: str, from_token: Token, to_token: Token, limit: int - ): - rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - - limited = False - if len(rows) == limit: - to_token = rows[-1][0] - limited = True - - return [(row[0], (row[2],)) for row in rows], to_token, limited - class PushersStream(Stream): """A user has added/changed/removed a pusher @@ -406,7 +374,7 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_pushers_stream_token), - db_query_to_update_function(store.get_all_updated_pushers_rows), + store.get_all_updated_pushers_rows, ) @@ -434,27 +402,13 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs): - self.store = hs.get_datastore() + store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self.store.get_cache_stream_token, - self._update_function, + store.get_cache_stream_token_for_writer, + store.get_all_updated_caches, ) - async def _update_function( - self, instance_name: str, from_token: int, upto_token: int, limit: int - ): - rows = await self.store.get_all_updated_caches( - instance_name, from_token, upto_token, limit - ) - updates = [(row[0], row[1:]) for row in rows] - limited = False - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - class PublicRoomsStream(Stream): """The public rooms list changed @@ -478,7 +432,7 @@ class PublicRoomsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_current_public_room_stream_id), - db_query_to_update_function(store.get_all_new_public_rooms), + store.get_all_new_public_rooms, ) @@ -499,7 +453,7 @@ class DeviceListsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), - db_query_to_update_function(store.get_all_device_list_changes_for_remotes), + store.get_all_device_list_changes_for_remotes, ) @@ -517,7 +471,7 @@ class ToDeviceStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_to_device_stream_token), - db_query_to_update_function(store.get_all_new_device_messages), + store.get_all_new_device_messages, ) @@ -537,7 +491,7 @@ class TagAccountDataStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_account_data_stream_id), - db_query_to_update_function(store.get_all_updated_tags), + store.get_all_updated_tags, ) @@ -625,7 +579,7 @@ class GroupServerStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_group_stream_token), - db_query_to_update_function(store.get_all_groups_changes), + store.get_all_groups_changes, ) @@ -643,7 +597,5 @@ class UserSignatureStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), - db_query_to_update_function( - store.get_all_user_signature_changes_for_remotes - ), + store.get_all_user_signature_changes_for_remotes, ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f370390331..f929fc3954 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,16 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import heapq -from collections import Iterable +from collections.abc import Iterable from typing import List, Tuple, Type import attr from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance - """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' @@ -51,20 +49,20 @@ data part are: @attr.s(slots=True, frozen=True) -class EventsStreamRow(object): +class EventsStreamRow: """A parsed row from the events replication stream""" type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow -class BaseEventsStreamRow(object): +class BaseEventsStreamRow: """Base class for rows to be sent in the events stream. Specifies how to identify, serialize and deserialize the different types. """ - # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overridden in sub classes. TypeId = None # type: str @classmethod diff --git a/synapse/res/templates/mail-Element.css b/synapse/res/templates/mail-Element.css new file mode 100644 index 0000000000..6a3e36eda1 --- /dev/null +++ b/synapse/res/templates/mail-Element.css @@ -0,0 +1,7 @@ +.header { + border-bottom: 4px solid #e4f7ed ! important; +} + +.notif_link a, .footer a { + color: #76CFA6 ! important; +} diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html index 6b94d8c367..d87311f659 100644 --- a/synapse/res/templates/notice_expiry.html +++ b/synapse/res/templates/notice_expiry.html @@ -22,6 +22,8 @@ <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> {% elif app_name == "Vector" %} <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> + {% elif app_name == "Element" %} + <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/> {% else %} <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> {% endif %} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html index 019506e5fb..a2dfeb9e9f 100644 --- a/synapse/res/templates/notif_mail.html +++ b/synapse/res/templates/notif_mail.html @@ -22,6 +22,8 @@ <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> {% elif app_name == "Vector" %} <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> + {% elif app_name == "Element" %} + <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/> {% else %} <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> {% endif %} diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html index bfd6449c5d..01cd9bdaf3 100644 --- a/synapse/res/templates/saml_error.html +++ b/synapse/res/templates/saml_error.html @@ -2,10 +2,17 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <title>SSO error</title> + <title>SSO login error</title> </head> <body> - <p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p> +{# a 403 means we have actively rejected their login #} +{% if code == 403 %} + <p>You are not allowed to log in here.</p> +{% else %} + <p> + There was an error during authentication: + </p> + <div id="errormsg" style="margin:20px 80px">{{ msg }}</div> <p> If you are seeing this page after clicking a link sent to you via email, make sure you only click the confirmation link once, and that you open the @@ -37,9 +44,9 @@ // to print one. let errorDesc = new URLSearchParams(searchStr).get("error_description") if (errorDesc) { - - document.getElementById("errormsg").innerText = ` ("${errorDesc}")`; + document.getElementById("errormsg").innerText = errorDesc; } </script> +{% endif %} </body> -</html> \ No newline at end of file +</html> diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 46e458e95b..87f927890c 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import ( room_keys, room_upgrade_rest_servlet, sendtodevice, + shared_rooms, sync, tags, thirdparty, @@ -125,3 +126,6 @@ class ClientRestResource(JsonResource): synapse.rest.admin.register_servlets_for_client_rest_resource( hs, client_resource ) + + # unstable + shared_rooms.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 9eda592de9..1c88c93f38 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -35,8 +35,10 @@ from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( + DeleteRoomRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, + RoomMembersRestServlet, RoomRestServlet, ShutdownRoomRestServlet, ) @@ -200,6 +202,8 @@ def register_servlets(hs, http_server): register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) + RoomMembersRestServlet(hs).register(http_server) + DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 8173baef8f..09726d52d6 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import List, Optional -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -30,9 +31,8 @@ from synapse.rest.admin._base import ( assert_user_is_admin, historical_admin_path_patterns, ) -from synapse.storage.data_stores.main.room import RoomSortOrder +from synapse.storage.databases.main.room import RoomSortOrder from synapse.types import RoomAlias, RoomID, UserID, create_requester -from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -46,20 +46,10 @@ class ShutdownRoomRestServlet(RestServlet): PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)") - DEFAULT_MESSAGE = ( - "Sharing illegal content on this server is not permitted and rooms in" - " violation will be blocked." - ) - def __init__(self, hs): self.hs = hs - self.store = hs.get_datastore() - self.state = hs.get_state_handler() - self._room_creation_handler = hs.get_room_creation_handler() - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - self._replication = hs.get_replication_data_handler() + self.room_shutdown_handler = hs.get_room_shutdown_handler() async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -67,116 +57,74 @@ class ShutdownRoomRestServlet(RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ["new_room_user_id"]) - new_room_user_id = content["new_room_user_id"] - - room_creator_requester = create_requester(new_room_user_id) - message = content.get("message", self.DEFAULT_MESSAGE) - room_name = content.get("room_name", "Content Violation Notification") - - info, stream_id = await self._room_creation_handler.create_room( - room_creator_requester, - config={ - "preset": "public_chat", - "name": room_name, - "power_level_content_override": {"users_default": -10}, - }, - ratelimit=False, + ret = await self.room_shutdown_handler.shutdown_room( + room_id=room_id, + new_room_user_id=content["new_room_user_id"], + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=True, ) - new_room_id = info["room_id"] - requester_user_id = requester.user.to_string() + return (200, ret) - logger.info( - "Shutting down room %r, joining to new room: %r", room_id, new_room_id - ) - # This will work even if the room is already blocked, but that is - # desirable in case the first attempt at blocking the room failed below. - await self.store.block_room(room_id, requester_user_id) - - # We now wait for the create room to come back in via replication so - # that we can assume that all the joins/invites have propogated before - # we try and auto join below. - # - # TODO: Currently the events stream is written to from master - await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id - ) - - users = await self.state.get_current_users_in_room(room_id) - kicked_users = [] - failed_to_kick_users = [] - for user_id in users: - if not self.hs.is_mine_id(user_id): - continue +class DeleteRoomRestServlet(RestServlet): + """Delete a room from server. It is a combination and improvement of + shut down and purge room. + Shuts down a room by removing all local users from the room. + Blocking all future invites and joins to the room is optional. + If desired any local aliases will be repointed to a new room + created by `new_room_user_id` and kicked users will be auto + joined to the new room. + It will remove all trace of a room from the database. + """ - logger.info("Kicking %r from %r...", user_id, room_id) + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$") - try: - target_requester = create_requester(user_id) - _, stream_id = await self.room_member_handler.update_membership( - requester=target_requester, - target=target_requester.user, - room_id=room_id, - action=Membership.LEAVE, - content={}, - ratelimit=False, - require_consent=False, - ) + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.room_shutdown_handler = hs.get_room_shutdown_handler() + self.pagination_handler = hs.get_pagination_handler() - # Wait for leave to come in over replication before trying to forget. - await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id - ) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) - await self.room_member_handler.forget(target_requester.user, room_id) + content = parse_json_object_from_request(request) - await self.room_member_handler.update_membership( - requester=target_requester, - target=target_requester.user, - room_id=new_room_id, - action=Membership.JOIN, - content={}, - ratelimit=False, - require_consent=False, - ) + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) - kicked_users.append(user_id) - except Exception: - logger.exception( - "Failed to leave old room and join new room for %r", user_id - ) - failed_to_kick_users.append(user_id) - - await self.event_creation_handler.create_and_send_nonmember_event( - room_creator_requester, - { - "type": "m.room.message", - "content": {"body": message, "msgtype": "m.text"}, - "room_id": new_room_id, - "sender": new_room_user_id, - }, - ratelimit=False, - ) + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) - aliases_for_room = await maybe_awaitable( - self.store.get_aliases_for_room(room_id) + ret = await self.room_shutdown_handler.shutdown_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, ) - await self.store.update_aliases_for_room( - room_id, new_room_id, requester_user_id - ) + # Purge room + if purge: + await self.pagination_handler.purge_room(room_id) - return ( - 200, - { - "kicked_users": kicked_users, - "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, - "new_room_id": new_room_id, - }, - ) + return (200, ret) class ListRoomRestServlet(RestServlet): @@ -292,6 +240,31 @@ class RoomRestServlet(RestServlet): return 200, ret +class RoomMembersRestServlet(RestServlet): + """ + Get members list of a room. + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + await assert_requester_is_admin(self.auth, request) + + ret = await self.store.get_room(room_id) + if not ret: + raise NotFoundError("Room not found") + + members = await self.store.get_users_in_room(room_id) + ret = {"members": members, "total": len(members)} + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") @@ -343,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet): join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + # update_membership with an action of "invite" can raise a + # ShadowBanError. This is not handled since it is assumed that + # an admin isn't going to call this API with a shadow-banned user. await self.room_member_handler.update_membership( requester=requester, target=fake_requester.user, diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index fefc8f71fa..f3e77da850 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -16,9 +16,7 @@ import hashlib import hmac import logging import re - -from six import text_type -from six.moves import http_client +from http import HTTPStatus from synapse.api.constants import UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -75,6 +73,7 @@ class UsersRestServletV2(RestServlet): The parameters `from` and `limit` are required only for pagination. By default, a `limit` of 100 is used. The parameter `user_id` can be used to filter by user id. + The parameter `name` can be used to filter by user id or display name. The parameter `guests` can be used to exclude guest users. The parameter `deactivated` can be used to include deactivated users. """ @@ -91,11 +90,12 @@ class UsersRestServletV2(RestServlet): start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) user_id = parse_string(request, "user_id", default=None) + name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) users, total = await self.store.get_users_paginate( - start, limit, user_id, guests, deactivated + start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} if len(users) >= limit: @@ -215,10 +215,7 @@ class UserRestServletV2(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) if "password" in body: - if ( - not isinstance(body["password"], text_type) - or len(body["password"]) > 512 - ): + if not isinstance(body["password"], str) or len(body["password"]) > 512: raise SynapseError(400, "Invalid password") else: new_password = body["password"] @@ -244,6 +241,15 @@ class UserRestServletV2(RestServlet): await self.deactivate_account_handler.deactivate_account( target_user.to_string(), False ) + elif not deactivate and user["deactivated"]: + if "password" not in body: + raise SynapseError( + 400, "Must provide a password to re-activate an account." + ) + + await self.deactivate_account_handler.activate_account( + target_user.to_string() + ) user = await self.admin_handler.get_user(target_user) return 200, user @@ -252,14 +258,13 @@ class UserRestServletV2(RestServlet): password = body.get("password") password_hash = None if password is not None: - if not isinstance(password, text_type) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") password_hash = await self.auth_handler.hash(password) admin = body.get("admin", None) user_type = body.get("user_type", None) displayname = body.get("displayname", None) - threepids = body.get("threepids", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: raise SynapseError(400, "Invalid user type") @@ -370,10 +375,7 @@ class UserRegisterServlet(RestServlet): 400, "username must be specified", errcode=Codes.BAD_JSON ) else: - if ( - not isinstance(body["username"], text_type) - or len(body["username"]) > 512 - ): + if not isinstance(body["username"], str) or len(body["username"]) > 512: raise SynapseError(400, "Invalid username") username = body["username"].encode("utf-8") @@ -386,7 +388,7 @@ class UserRegisterServlet(RestServlet): ) else: password = body["password"] - if not isinstance(password, text_type) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") password_bytes = password.encode("utf-8") @@ -477,7 +479,7 @@ class DeactivateAccountRestServlet(RestServlet): erase = body.get("erase", False) if not isinstance(erase, bool): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'erase' must be a boolean, if given", Codes.BAD_JSON, ) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 6da71dc46f..7be5c0fb88 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins -class HttpTransactionCache(object): +class HttpTransactionCache: def __init__(self, hs): self.hs = hs self.auth = self.hs.get_auth() diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 5934b1fe8b..b210015173 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet): dir_handler = self.handlers.directory_handler try: - service = await self.auth.get_appservice_by_req(request) + service = self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) await dir_handler.delete_appservice_association(service, room_alias) logger.info( diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index dceb2792fa..a14618ac84 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,9 +14,14 @@ # limitations under the License. import logging +from typing import Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.handlers.auth import ( + convert_client_dict_legacy_fields_to_identifier, + login_id_phone_to_thirdparty, +) from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -26,64 +31,36 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder -from synapse.types import UserID -from synapse.util.msisdn import phone_number_to_msisdn +from synapse.types import JsonDict, UserID +from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) -def login_submission_legacy_convert(submission): - """ - If the input login submission is an old style object - (ie. with top-level user / medium / address) convert it - to a typed object. - """ - if "user" in submission: - submission["identifier"] = {"type": "m.id.user", "user": submission["user"]} - del submission["user"] - - if "medium" in submission and "address" in submission: - submission["identifier"] = { - "type": "m.id.thirdparty", - "medium": submission["medium"], - "address": submission["address"], - } - del submission["medium"] - del submission["address"] - - -def login_id_thirdparty_from_phone(identifier): - """ - Convert a phone login identifier type to a generic threepid identifier - Args: - identifier(dict): Login identifier dict of type 'm.id.phone' - - Returns: Login identifier dict of type 'm.id.threepid' - """ - if "country" not in identifier or "number" not in identifier: - raise SynapseError(400, "Invalid phone-type identifier") - - msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) - - return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} - - class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" SSO_TYPE = "m.login.sso" TOKEN_TYPE = "m.login.token" - JWT_TYPE = "m.login.jwt" + JWT_TYPE = "org.matrix.login.jwt" + JWT_TYPE_DEPRECATED = "m.login.jwt" def __init__(self, hs): super(LoginRestServlet, self).__init__() self.hs = hs + + # JWT configuration variables. self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm + self.jwt_issuer = hs.config.jwt_issuer + self.jwt_audiences = hs.config.jwt_audiences + + # SSO configuration. self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -104,10 +81,11 @@ class LoginRestServlet(RestServlet): burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) - def on_GET(self, request): + def on_GET(self, request: SynapseRequest): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) + flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) if self.cas_enabled: # we advertise CAS for backwards compat, though MSC1721 renamed it @@ -131,20 +109,21 @@ class LoginRestServlet(RestServlet): return 200, {"flows": flows} - def on_OPTIONS(self, request): + def on_OPTIONS(self, request: SynapseRequest): return 200, {} - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest): self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) try: if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE + or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): - result = await self.do_jwt_login(login_submission) + result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - result = await self.do_token_login(login_submission) + result = await self._do_token_login(login_submission) else: result = await self._do_other_login(login_submission) except KeyError: @@ -155,14 +134,14 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - async def _do_other_login(self, login_submission): + async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: """Handle non-token/saml/jwt logins Args: login_submission: Returns: - dict: HTTP response + HTTP response """ # Log the request we got, but only certain fields to minimise the chance of # logging someone's password (even if they accidentally put it in the wrong @@ -174,18 +153,11 @@ class LoginRestServlet(RestServlet): login_submission.get("address"), login_submission.get("user"), ) - login_submission_legacy_convert(login_submission) - - if "identifier" not in login_submission: - raise SynapseError(400, "Missing param: identifier") - - identifier = login_submission["identifier"] - if "type" not in identifier: - raise SynapseError(400, "Login identifier has no type") + identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) # convert phone type identifiers to generic threepids if identifier["type"] == "m.id.phone": - identifier = login_id_thirdparty_from_phone(identifier) + identifier = login_id_phone_to_thirdparty(identifier) # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": @@ -195,11 +167,14 @@ class LoginRestServlet(RestServlet): if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) if medium == "email": - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. - # (See add_threepid in synapse/handlers/auth.py) - address = address.lower() + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. @@ -277,25 +252,30 @@ class LoginRestServlet(RestServlet): return result async def _complete_login( - self, user_id, login_submission, callback=None, create_non_existent_users=False - ): + self, + user_id: str, + login_submission: JsonDict, + callback: Optional[ + Callable[[Dict[str, str]], Awaitable[Dict[str, str]]] + ] = None, + create_non_existent_users: bool = False, + ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on - all succesful logins. + all successful logins. - Applies the ratelimiting for succesful login attempts against an + Applies the ratelimiting for successful login attempts against an account. Args: - user_id (str): ID of the user to register. - login_submission (dict): Dictionary of login information. - callback (func|None): Callback function to run after registration. - create_non_existent_users (bool): Whether to create the user if - they don't exist. Defaults to False. + user_id: ID of the user to register. + login_submission: Dictionary of login information. + callback: Callback function to run after registration. + create_non_existent_users: Whether to create the user if they don't + exist. Defaults to False. Returns: - result (Dict[str,str]): Dictionary of account information after - successful registration. + result: Dictionary of account information after successful registration. """ # Before we actually log them in we check if they've already logged in @@ -329,7 +309,7 @@ class LoginRestServlet(RestServlet): return result - async def do_token_login(self, login_submission): + async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: token = login_submission["token"] auth_handler = self.auth_handler user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( @@ -339,28 +319,32 @@ class LoginRestServlet(RestServlet): result = await self._complete_login(user_id, login_submission) return result - async def do_jwt_login(self, login_submission): + async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED + 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN ) import jwt - from jwt.exceptions import InvalidTokenError try: payload = jwt.decode( - token, self.jwt_secret, algorithms=[self.jwt_algorithm] + token, + self.jwt_secret, + algorithms=[self.jwt_algorithm], + issuer=self.jwt_issuer, + audience=self.jwt_audiences, + ) + except jwt.PyJWTError as e: + # A JWT error occurred, return some info back to the client. + raise LoginError( + 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN, ) - except jwt.ExpiredSignatureError: - raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) - except InvalidTokenError: - raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user = payload.get("sub", None) if user is None: - raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index eec16f8ad8..970fdd5834 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -17,8 +17,6 @@ """ import logging -from six import string_types - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -51,7 +49,9 @@ class PresenceStatusRestServlet(RestServlet): raise AuthError(403, "You are not allowed to see their presence.") state = await self.presence_handler.get_state(target_user=user) - state = format_user_presence_state(state, self.clock.time_msec()) + state = format_user_presence_state( + state, self.clock.time_msec(), include_user_id=False + ) return 200, state @@ -71,7 +71,7 @@ class PresenceStatusRestServlet(RestServlet): if "status_msg" in content: state["status_msg"] = content.pop("status_msg") - if not isinstance(state["status_msg"], string_types): + if not isinstance(state["status_msg"], str): raise SynapseError(400, "status_msg must be a string.") if content: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 9fd4908136..e781a3bcf4 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from synapse.api.errors import ( NotFoundError, StoreError, @@ -25,7 +24,7 @@ from synapse.http.servlet import ( parse_json_value_from_request, parse_string, ) -from synapse.push.baserules import BASE_RULE_IDS +from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client.v2_alpha._base import client_patterns @@ -45,6 +44,8 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None + self._users_new_default_push_rules = hs.config.users_new_default_push_rules + async def on_PUT(self, request, path): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -158,10 +159,10 @@ class PushRuleRestServlet(RestServlet): return 200, {} def notify_user(self, user_id): - stream_id, _ = self.store.get_push_rules_stream_token() + stream_id = self.store.get_max_push_rules_stream_id() self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - def set_rule_attr(self, user_id, spec, val): + async def set_rule_attr(self, user_id, spec, val): if spec["attr"] == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] @@ -171,7 +172,9 @@ class PushRuleRestServlet(RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val) + return await self.store.set_push_rule_enabled( + user_id, namespaced_rule_id, val + ) elif spec["attr"] == "actions": actions = val.get("actions") _check_actions(actions) @@ -179,9 +182,14 @@ class PushRuleRestServlet(RestServlet): rule_id = spec["rule_id"] is_default_rule = rule_id.startswith(".") if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: + if user_id in self._users_new_default_push_rules: + rule_ids = NEW_RULE_IDS + else: + rule_ids = BASE_RULE_IDS + + if namespaced_rule_id not in rule_ids: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return self.store.set_push_rule_actions( + return await self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 550a2f1b44..5f65cb7d83 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -16,7 +16,7 @@ import logging from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.http.server import finish_request +from synapse.http.server import respond_with_html_bytes from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -177,13 +177,9 @@ class PushersRemoveRestServlet(RestServlet): self.notifier.on_new_replication_data() - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader( - b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),) + respond_with_html_bytes( + request, 200, PushersRemoveRestServlet.SUCCESS_HTML, ) - request.write(PushersRemoveRestServlet.SUCCESS_HTML) - finish_request(request) return None def on_OPTIONS(self, _): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 105e0cf4d2..84baf3d59b 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -15,13 +15,11 @@ # limitations under the License. """ This module contains REST servlets to do with rooms: /rooms/<paths> """ + import logging import re from typing import List, Optional - -from six.moves.urllib import parse as urlparse - -from canonicaljson import json +from urllib import parse as urlparse from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( @@ -29,6 +27,7 @@ from synapse.api.errors import ( Codes, HttpResponseException, InvalidClientCredentialsError, + ShadowBanError, SynapseError, ) from synapse.api.filtering import Filter @@ -46,6 +45,8 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +from synapse.util import json_decoder +from synapse.util.stringutils import random_string MYPY = False if MYPY: @@ -170,7 +171,6 @@ class RoomStateEventRestServlet(TransactionRestServlet): room_id=room_id, event_type=event_type, state_key=state_key, - is_guest=requester.is_guest, ) if not data: @@ -200,28 +200,29 @@ class RoomStateEventRestServlet(TransactionRestServlet): if state_key is not None: event_dict["state_key"] = state_key - if event_type == EventTypes.Member: - membership = content.get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - requester, - target=UserID.from_string(state_key), - room_id=room_id, - action=membership, - content=content, - ) - else: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) - event_id = event.event_id + try: + if event_type == EventTypes.Member: + membership = content.get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + requester, + target=UserID.from_string(state_key), + room_id=room_id, + action=membership, + content=content, + ) + else: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - ret = {} # type: dict - if event_id: - set_tag("event_id", event_id) - ret = {"event_id": event_id} + set_tag("event_id", event_id) + ret = {"event_id": event_id} return 200, ret @@ -251,12 +252,19 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - set_tag("event_id", event.event_id) - return 200, {"event_id": event.event_id} + set_tag("event_id", event_id) + return 200, {"event_id": event_id} def on_GET(self, request, room_id, event_type, txn_id): return 200, "Not implemented" @@ -446,7 +454,7 @@ class RoomMemberListRestServlet(RestServlet): async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - requester = await self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request, allow_guest=True) handler = self.message_handler # request the state as of a given event, as identified by a stream token, @@ -518,10 +526,12 @@ class RoomMessageListRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args - filter_bytes = parse_string(request, b"filter", encoding=None) - if filter_bytes: - filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -630,10 +640,12 @@ class RoomEventContextServlet(RestServlet): limit = parse_integer(request, "limit", default=10) # picking the API shape for symmetry with /messages - filter_bytes = parse_string(request, "filter") - if filter_bytes: - filter_json = urlparse.unquote(filter_bytes) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] else: event_filter = None @@ -718,16 +730,20 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - await self.room_member_handler.do_3pid_invite( - room_id, - requester.user, - content["medium"], - content["address"], - content["id_server"], - requester, - txn_id, - content.get("id_access_token"), - ) + try: + await self.room_member_handler.do_3pid_invite( + room_id, + requester.user, + content["medium"], + content["address"], + content["id_server"], + requester, + txn_id, + content.get("id_access_token"), + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return 200, {} target = requester.user @@ -739,15 +755,19 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content: event_content = {"reason": content["reason"]} - await self.room_member_handler.update_membership( - requester=requester, - target=target, - room_id=room_id, - action=membership_action, - txn_id=txn_id, - third_party_signed=content.get("third_party_signed", None), - content=event_content, - ) + try: + await self.room_member_handler.update_membership( + requester=requester, + target=target, + room_id=room_id, + action=membership_action, + txn_id=txn_id, + third_party_signed=content.get("third_party_signed", None), + content=event_content, + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return_value = {} @@ -785,20 +805,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - set_tag("event_id", event.event_id) - return 200, {"event_id": event.event_id} + set_tag("event_id", event_id) + return 200, {"event_id": event_id} def on_PUT(self, request, room_id, event_id, txn_id): set_tag("txn_id", txn_id) @@ -819,9 +846,18 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() + # If we're not on the typing writer instance we should scream if we get + # requests. + self._is_typing_writer = ( + hs.config.worker.writers.typing == hs.get_instance_name() + ) + async def on_PUT(self, request, room_id, user_id): requester = await self.auth.get_user_by_req(request) + if not self._is_typing_writer: + raise Exception("Got /typing request on instance that is not typing writer") + room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) @@ -832,17 +868,21 @@ class RoomTypingRestServlet(RestServlet): # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) - if content["typing"]: - await self.typing_handler.started_typing( - target_user=target_user, - auth_user=requester.user, - room_id=room_id, - timeout=timeout, - ) - else: - await self.typing_handler.stopped_typing( - target_user=target_user, auth_user=requester.user, room_id=room_id - ) + try: + if content["typing"]: + await self.typing_handler.started_typing( + target_user=target_user, + requester=requester, + room_id=room_id, + timeout=timeout, + ) + else: + await self.typing_handler.stopped_typing( + target_user=target_user, requester=requester, room_id=room_id + ) + except ShadowBanError: + # Pretend this worked without error. + pass return 200, {} diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 747d46eac2..50277c6cf6 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -50,7 +50,7 @@ class VoipRestServlet(RestServlet): # We need to use standard padded base64 encoding here # encode_base64 because we need to add the standard padding to get the # same result as the TURN server. - password = base64.b64encode(mac.digest()) + password = base64.b64encode(mac.digest()).decode("ascii") elif turnUris and turnUsername and turnPassword and userLifetime: username = turnUsername diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index bc11b4dda4..f016b4f1bd 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -17,24 +17,32 @@ """ import logging import re - -from twisted.internet import defer +from typing import Iterable, Pattern from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX +from synapse.types import JsonDict logger = logging.getLogger(__name__) -def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): +def client_patterns( + path_regex: str, + releases: Iterable[int] = (0,), + unstable: bool = True, + v1: bool = False, +) -> Iterable[Pattern]: """Creates a regex compiled client path with the correct client path prefix. Args: - path_regex (str): The regex string to match. This should NOT have a ^ + path_regex: The regex string to match. This should NOT have a ^ as this will be prefixed. + releases: An iterable of releases to include this endpoint under. + unstable: If true, include this endpoint under the "unstable" prefix. + v1: If true, include this endpoint under the "api/v1" prefix. Returns: - SRE_Pattern + An iterable of patterns. """ patterns = [] @@ -51,7 +59,15 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): return patterns -def set_timeline_upper_limit(filter_json, filter_timeline_limit): +def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None: + """ + Enforces a maximum limit of a timeline query. + + Params: + filter_json: The timeline query to modify. + filter_timeline_limit: The maximum limit to allow, passing -1 will + disable enforcing a maximum limit. + """ if filter_timeline_limit < 0: return # no upper limits timeline = filter_json.get("room", {}).get("timeline", {}) @@ -64,34 +80,22 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit): def interactive_auth_handler(orig): """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors - Takes a on_POST method which returns a deferred (errcode, body) response + Takes a on_POST method which returns an Awaitable (errcode, body) response and adds exception handling to turn a InteractiveAuthIncompleteError into a 401 response. Normal usage is: @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): # ... - yield self.auth_handler.check_auth - """ + await self.auth_handler.check_auth + """ - def wrapped(*args, **kwargs): - res = defer.ensureDeferred(orig(*args, **kwargs)) - res.addErrback(_catch_incomplete_interactive_auth) - return res + async def wrapped(*args, **kwargs): + try: + return await orig(*args, **kwargs) + except InteractiveAuthIncompleteError as e: + return 401, e.result return wrapped - - -def _catch_incomplete_interactive_auth(f): - """helper for interactive_auth_handler - - Catches InteractiveAuthIncompleteErrors and turns them into 401 responses - - Args: - f (failure.Failure): - """ - f.trap(InteractiveAuthIncompleteError) - return 401, f.value.result diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 8f9440da9a..cad3f9bbb7 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -15,23 +15,28 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from six.moves import http_client +import random +from http import HTTPStatus from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + InteractiveAuthIncompleteError, + SynapseError, + ThreepidValidationError, +) from synapse.config.emailconfig import ThreepidBehaviour -from synapse.http.server import finish_request +from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import Mailer, load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import check_3pid_allowed +from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -49,21 +54,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_password_reset_template_html, - self.config.email_password_reset_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_password_reset_template_html, + template_text=self.config.email_password_reset_template_text, ) async def on_POST(self, request): @@ -84,7 +79,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): client_secret = body["client_secret"] assert_valid_client_secret(client_secret) - email = body["email"] + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. This allows the user to reset his password without having to + # know the exact spelling (eg. upper and lower case) of address in the database. + # Stored in the database "foo@bar.com" + # User requests with "FOO@bar.com" would raise a Not Found error + try: + email = canonicalise_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -95,6 +98,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + # The email will be sent to the stored address. + # This avoids a potential account hijack by requesting a password reset to + # an email address which is controlled by the attacker but which, after + # canonicalisation, matches the one in our database. existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -103,6 +110,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) @@ -153,9 +163,8 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_password_reset_template_failure_html], + self._failure_email_template = ( + self.config.email_password_reset_template_failure_html ) async def on_GET(self, request, medium): @@ -198,17 +207,16 @@ class PasswordResetSubmitTokenServlet(RestServlet): return None # Otherwise show the success template - html = self.config.email_password_reset_template_success_html - request.setResponseCode(200) + html = self.config.email_password_reset_template_success_html_content + status_code = 200 except ThreepidValidationError as e: - request.setResponseCode(e.code) + status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) - request.write(html.encode("utf-8")) - finish_request(request) + respond_with_html(request, status_code, html) class PasswordRestServlet(RestServlet): @@ -229,18 +237,12 @@ class PasswordRestServlet(RestServlet): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the new password provided to us. - if "new_password" in body: - new_password = body.pop("new_password") + new_password = body.pop("new_password", None) + if new_password is not None: if not isinstance(new_password, str) or len(new_password) > 512: raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(new_password) - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "new_password_hash" in body: - raise SynapseError(400, "Unexpected property: new_password_hash") - body["new_password_hash"] = await self.auth_handler.hash(new_password) - # there are two possibilities here. Either the user does not have an # access token, and needs to do a password reset; or they have one and # need to validate their identity. @@ -253,33 +255,62 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) - params = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + try: + params, session_id = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise user_id = requester.user.to_string() else: requester = None - result, params, _ = await self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY]], - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + try: + result, params, session_id = await self.auth_handler.check_ui_auth( + [[LoginType.EMAIL_IDENTITY]], + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise if LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] if "medium" not in threepid or "address" not in threepid: raise SynapseError(500, "Malformed threepid") if threepid["medium"] == "email": - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. # (See add_threepid in synapse/handlers/auth.py) - threepid["address"] = threepid["address"].lower() + try: + threepid["address"] = canonicalise_email(threepid["address"]) + except ValueError as e: + raise SynapseError(400, str(e)) # if using email, we must know about the email they're authing with! threepid_user_id = await self.datastore.get_user_id_by_threepid( threepid["medium"], threepid["address"] @@ -291,12 +322,21 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - assert_params_in_dict(params, ["new_password_hash"]) - new_password_hash = params["new_password_hash"] + # If we have a password in this request, prefer it. Otherwise, there + # must be a password hash from an earlier request. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + else: + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) + logout_devices = params.get("logout_devices", True) await self._set_password_handler.set_password( - user_id, new_password_hash, logout_devices, requester + user_id, password_hash, logout_devices, requester ) return 200, {} @@ -321,7 +361,7 @@ class DeactivateAccountRestServlet(RestServlet): erase = body.get("erase", False) if not isinstance(erase, bool): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'erase' must be a boolean, if given", Codes.BAD_JSON, ) @@ -364,19 +404,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_add_threepid_template_html, - self.config.email_add_threepid_template_text, - ], - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_add_threepid_template_html, + template_text=self.config.email_add_threepid_template_text, ) async def on_POST(self, request): @@ -394,7 +426,16 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): client_secret = body["client_secret"] assert_valid_client_secret(client_secret) - email = body["email"] + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. + # This ensures that the validation email is sent to the canonicalised address + # as it will later be entered into the database. + # Otherwise the email will be sent to "FOO@bar.com" and stored as + # "foo@bar.com" in database. + try: + email = canonicalise_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -405,14 +446,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = await self.store.get_user_id_by_threepid( - "email", body["email"] - ) + existing_user_id = await self.store.get_user_id_by_threepid("email", email) if existing_user_id is not None: if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -481,6 +523,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -524,9 +569,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_add_threepid_template_failure_html], + self._failure_email_template = ( + self.config.email_add_threepid_template_failure_html ) async def on_GET(self, request): @@ -571,16 +615,15 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Otherwise show the success template html = self.config.email_add_threepid_template_success_html_content - request.setResponseCode(200) + status_code = 200 except ThreepidValidationError as e: - request.setResponseCode(e.code) + status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) - request.write(html.encode("utf-8")) - finish_request(request) + respond_with_html(request, status_code, html) class AddThreepidMsisdnSubmitTokenServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 2f10fa64e2..d06336ceea 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -16,7 +16,7 @@ import logging from synapse.api.errors import AuthError, SynapseError -from synapse.http.server import finish_request +from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet from ._base import client_patterns @@ -26,9 +26,6 @@ logger = logging.getLogger(__name__) class AccountValidityRenewServlet(RestServlet): PATTERNS = client_patterns("/account_validity/renew$") - SUCCESS_HTML = ( - b"<html><body>Your account has been successfully renewed.</body><html>" - ) def __init__(self, hs): """ @@ -59,11 +56,7 @@ class AccountValidityRenewServlet(RestServlet): status_code = 404 response = self.failure_html - request.setResponseCode(status_code) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(response),)) - request.write(response.encode("utf8")) - finish_request(request) + respond_with_html(request, status_code, response) class AccountValiditySendMailServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 75590ebaeb..8e585e9153 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -18,7 +18,7 @@ import logging from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.http.server import finish_request +from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet, parse_string from ._base import client_patterns @@ -200,13 +200,7 @@ class AuthRestServlet(RestServlet): raise SynapseError(404, "Unknown auth stage type") # Render the HTML and return. - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) + respond_with_html(request, 200, html) return None async def on_POST(self, request, stagetype): @@ -263,13 +257,7 @@ class AuthRestServlet(RestServlet): raise SynapseError(404, "Unknown auth stage type") # Render the HTML and return. - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) + respond_with_html(request, 200, html) return None def on_OPTIONS(self, _): diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 1efe60f3a7..075afdd32b 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -16,6 +16,7 @@ import logging +from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import GroupID @@ -325,6 +326,9 @@ class GroupRoomServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s was not legal group ID" % (group_id,)) + result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e2efc47024..ae1a8c4e6c 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,16 +16,16 @@ import hmac import logging +import random from typing import List, Union -from six import string_types - import synapse import synapse.api.auth import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, + InteractiveAuthIncompleteError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -38,18 +38,18 @@ from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved from synapse.handlers.auth import AuthHandler -from synapse.http.server import finish_request +from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import check_3pid_allowed +from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -82,23 +82,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): self.config = hs.config if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - from synapse.push.mailer import Mailer, load_jinja2_templates - - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_registration_template_html, - self.config.email_registration_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_registration_template_html, + template_text=self.config.email_registration_template_text, ) async def on_POST(self, request): @@ -118,7 +106,14 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): client_secret = body["client_secret"] assert_valid_client_secret(client_secret) - email = body["email"] + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/v2_alpha/account.py) + try: + email = canonicalise_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -130,13 +125,16 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( - "email", body["email"] + "email", email ) if existing_user_id is not None: if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -209,6 +207,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError( @@ -256,15 +257,8 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], - ) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], + self._failure_email_template = ( + self.config.email_registration_template_failure_html ) async def on_GET(self, request, medium): @@ -306,17 +300,15 @@ class RegistrationSubmitTokenServlet(RestServlet): # Otherwise show the success template html = self.config.email_registration_template_success_html_content - - request.setResponseCode(200) + status_code = 200 except ThreepidValidationError as e: - request.setResponseCode(e.code) + status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) - request.write(html.encode("utf-8")) - finish_request(request) + respond_with_html(request, status_code, html) class UsernameAvailabilityRestServlet(RestServlet): @@ -384,6 +376,7 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self._registration_enabled = self.hs.config.enable_registration self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -409,32 +402,17 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) - # we do basic sanity checks here because the auth layer will store these - # in sessions. Pull out the username/password provided to us. - if "password" in body: - password = body.pop("password") - if not isinstance(password, string_types) or len(password) > 512: - raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(password) - - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "password_hash" in body: - raise SynapseError(400, "Unexpected property: password_hash") - body["password_hash"] = await self.auth_handler.hash(password) - + # Pull out the provided username and do basic sanity checks early since + # the auth layer will store these in sessions. desired_username = None if "username" in body: - if ( - not isinstance(body["username"], string_types) - or len(body["username"]) > 512 - ): + if not isinstance(body["username"], str) or len(body["username"]) > 512: raise SynapseError(400, "Invalid username") desired_username = body["username"] appservice = None if self.auth.has_access_token(request): - appservice = await self.auth.get_appservice_by_req(request) + appservice = self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -453,28 +431,41 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) - if isinstance(desired_username, string_types): + if isinstance(desired_username, str): result = await self._do_appservice_registration( desired_username, access_token, body ) return 200, result # we throw for non 200 responses - # for regular registration, downcase the provided username before - # attempting to register it. This should mean - # that people who try to register with upper-case in their usernames - # don't get a nasty surprise. (Note that we treat username - # case-insenstively in login, so they are free to carry on imagining - # that their username is CrAzYh4cKeR if that keeps them happy) - if desired_username is not None: - desired_username = desired_username.lower() - # == Normal User Registration == (everyone else) - if not self.hs.config.enable_registration: + if not self._registration_enabled: raise SynapseError(403, "Registration has been disabled") + # For regular registration, convert the provided username to lowercase + # before attempting to register it. This should mean that people who try + # to register with upper-case in their usernames don't get a nasty surprise. + # + # Note that we treat usernames case-insensitively in login, so they are + # free to carry on imagining that their username is CrAzYh4cKeR if that + # keeps them happy. + if desired_username is not None: + desired_username = desired_username.lower() + + # Check if this account is upgrading from a guest account. guest_access_token = body.get("guest_access_token", None) - if "initial_device_display_name" in body and "password_hash" not in body: + # Pull out the provided password and do basic sanity checks early. + # + # Note that we remove the password from the body since the auth layer + # will store the body in the session and we don't want a plaintext + # password store there. + password = body.pop("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + if "initial_device_display_name" in body and password is None: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out @@ -484,6 +475,7 @@ class RegisterRestServlet(RestServlet): session_id = self.auth_handler.get_session_id(body) registered_user_id = None + password_hash = None if session_id: # if we get a registered user id out of here, it means we previously # registered a user for this session, so we could just return the @@ -492,7 +484,12 @@ class RegisterRestServlet(RestServlet): registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) + # Extract the previously-hashed password from the session. + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) + # Ensure that the username is valid. if desired_username is not None: await self.registration_handler.check_username( desired_username, @@ -500,20 +497,38 @@ class RegisterRestServlet(RestServlet): assigned_user_id=registered_user_id, ) - auth_result, params, session_id = await self.auth_handler.check_auth( - self._registration_flows, - request, - body, - self.hs.get_ip_from_request(request), - "register a new account", - ) + # Check if the user-interactive authentication flows are complete, if + # not this will raise a user-interactive auth error. + try: + auth_result, params, session_id = await self.auth_handler.check_ui_auth( + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), + "register a new account", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth. + # + # Hash the password and store it with the session since the client + # is not required to provide the password again. + # + # If a password hash was previously stored we will not attempt to + # re-hash and store it for efficiency. This assumes the password + # does not change throughout the authentication flow, but this + # should be fine since the data is meant to be consistent. + if not password_hash and password: + password_hash = await self.auth_handler.hash(password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise # Check that we're not trying to register a denied 3pid. # # the user-facing checks will probably already have happened in # /register/email/requestToken when we requested a 3pid, but that's not # guaranteed. - if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: @@ -535,12 +550,15 @@ class RegisterRestServlet(RestServlet): # don't re-register the threepids registered = False else: - # NB: This may be from the auth handler and NOT from the POST - assert_params_in_dict(params, ["password_hash"]) + # If we have a password in this request, prefer it. Otherwise, there + # might be a password hash from an earlier request. + if password: + password_hash = await self.auth_handler.hash(password) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) desired_username = params.get("username", None) guest_access_token = params.get("guest_access_token", None) - new_password_hash = params.get("password_hash", None) if desired_username is not None: desired_username = desired_username.lower() @@ -559,6 +577,15 @@ class RegisterRestServlet(RestServlet): if login_type in auth_result: medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/v2_alpha/account.py) + if medium == "email": + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) existing_user_id = await self.store.get_user_id_by_threepid( medium, address @@ -571,12 +598,17 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, - password_hash=new_password_hash, + password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, + user_agent_ips=entries, ) # Necessary due to auth checks prior to the threepid being # written to the db @@ -586,8 +618,8 @@ class RegisterRestServlet(RestServlet): ): await self.store.upsert_monthly_active_user(registered_user_id) - # remember that we've now registered that user account, and with - # what user ID (since the user may not have specified) + # Remember that the user account has been registered (and the user + # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) @@ -626,7 +658,7 @@ class RegisterRestServlet(RestServlet): (object) params: registration parameters, from which we pull device_id, initial_device_name and inhibit_login Returns: - defer.Deferred: (object) dictionary for response from /register + dictionary for response from /register """ result = {"user_id": user_id, "home_server": self.hs.hostname} if not params.get("inhibit_login", False): diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 89002ffbff..e29f49f7f5 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -22,7 +22,7 @@ any time to reflect changes in the MSC. import logging from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import ShadowBanError, SynapseError from synapse.http.servlet import ( RestServlet, parse_integer, @@ -35,6 +35,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) +from synapse.util.stringutils import random_string from ._base import client_patterns @@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict=event_dict, txn_id=txn_id - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - return 200, {"event_id": event.event_id} + return 200, {"event_id": event_id} class RelationPaginationServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index f067b5edac..e15927c4ea 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -14,9 +14,7 @@ # limitations under the License. import logging - -from six import string_types -from six.moves import http_client +from http import HTTPStatus from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( @@ -47,15 +45,15 @@ class ReportEventRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ("reason", "score")) - if not isinstance(body["reason"], string_types): + if not isinstance(body["reason"], str): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'reason' must be a string", Codes.BAD_JSON, ) if not isinstance(body["score"], int): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", Codes.BAD_JSON, ) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index f357015a70..39a5518614 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -15,13 +15,14 @@ import logging -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.util import stringutils from ._base import client_patterns @@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ("new_version",)) - new_version = content["new_version"] new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) if new_version is None: @@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet): Codes.UNSUPPORTED_ROOM_VERSION, ) - new_room_id = await self._room_creation_handler.upgrade_room( - requester, room_id, new_version - ) + try: + new_room_id = await self._room_creation_handler.upgrade_room( + requester, room_id, new_version + ) + except ShadowBanError: + # Generate a random room ID. + new_room_id = stringutils.random_string(18) ret = {"replacement_room": new_room_id} diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py new file mode 100644 index 0000000000..2492634dac --- /dev/null +++ b/synapse/rest/client/v2_alpha/shared_rooms.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet +from synapse.types import UserID + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class UserSharedRoomsServlet(RestServlet): + """ + GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)", + releases=(), # This is an unstable feature + ) + + def __init__(self, hs): + super(UserSharedRoomsServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.user_directory_active = hs.config.update_user_directory + + async def on_GET(self, request, user_id): + + if not self.user_directory_active: + raise SynapseError( + code=400, + msg="The user directory is disabled on this server. Cannot determine shared rooms.", + errcode=Codes.FORBIDDEN, + ) + + UserID.from_string(user_id) + + requester = await self.auth.get_user_by_req(request) + if user_id == requester.user.to_string(): + raise SynapseError( + code=400, + msg="You cannot request a list of shared rooms with yourself", + errcode=Codes.FORBIDDEN, + ) + rooms = await self.store.get_shared_rooms_for_users( + requester.user.to_string(), user_id + ) + + return 200, {"joined": list(rooms)} + + +def register_servlets(hs, http_server): + UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 8fa68dd37f..a0b00135e1 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -16,8 +16,6 @@ import itertools import logging -from canonicaljson import json - from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection @@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken +from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet): filter_collection = DEFAULT_FILTER_COLLECTION elif filter_id.startswith("{"): try: - filter_object = json.loads(filter_id) + filter_object = json_decoder.decode(filter_id) set_timeline_upper_limit( filter_object, self.hs.config.filter_timeline_limit ) @@ -178,14 +177,22 @@ class SyncRestServlet(RestServlet): full_state=full_state, ) + # the client may have disconnected by now; don't bother to serialize the + # response if so. + if request._disconnected: + logger.info("Client has disconnected; not serializing response.") + return 200, {} + time_now = self.clock.time_msec() response_content = await self.encode_response( time_now, sync_result, requester.access_token_id, filter_collection ) + logger.debug("Event formatting complete") return 200, response_content async def encode_response(self, time_now, sync_result, access_token_id, filter): + logger.debug("Formatting events in sync response") if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id elif filter.event_format == "federation": @@ -213,6 +220,7 @@ class SyncRestServlet(RestServlet): event_formatter, ) + logger.debug("building sync response dict") return { "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, @@ -417,6 +425,7 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 0d668df0b6..24ac57f35d 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -60,6 +60,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.e2e_cross_signing": True, # Implements additional endpoints as described in MSC2432 "org.matrix.msc2432": True, + # Implements additional endpoints as described in MSC2666 + "uk.half-shot.msc2666": True, }, }, ) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 1ddf9997ff..b3e4d5612e 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -16,22 +16,15 @@ import hmac import logging from hashlib import sha256 +from http import HTTPStatus from os import path -from six.moves import http_client - import jinja2 from jinja2 import TemplateNotFound -from twisted.internet import defer - from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.config import ConfigError -from synapse.http.server import ( - DirectServeResource, - finish_request, - wrap_html_request_handler, -) +from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.servlet import parse_string from synapse.types import UserID @@ -49,7 +42,7 @@ else: return a == b -class ConsentResource(DirectServeResource): +class ConsentResource(DirectServeHtmlResource): """A twisted Resource to display a privacy policy and gather consent to it When accessed via GET, returns the privacy policy via a template. @@ -120,7 +113,6 @@ class ConsentResource(DirectServeResource): self._hmac_secret = hs.config.form_secret.encode("utf-8") - @wrap_html_request_handler async def _async_render_GET(self, request): """ Args: @@ -141,7 +133,7 @@ class ConsentResource(DirectServeResource): else: qualified_user_id = UserID(username, self.hs.hostname).to_string() - u = await defer.maybeDeferred(self.store.get_user_by_id, qualified_user_id) + u = await self.store.get_user_by_id(qualified_user_id) if u is None: raise NotFoundError("Unknown user") @@ -161,7 +153,6 @@ class ConsentResource(DirectServeResource): except TemplateNotFound: raise NotFoundError("Unknown policy version") - @wrap_html_request_handler async def _async_render_POST(self, request): """ Args: @@ -197,12 +188,8 @@ class ConsentResource(DirectServeResource): template_html = self._jinja_env.get_template( path.join(TEMPLATE_LANGUAGE, template_name) ) - html_bytes = template_html.render(**template_args).encode("utf8") - - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%i" % len(html_bytes)) - request.write(html_bytes) - finish_request(request) + html = template_html.render(**template_args) + respond_with_html(request, 200, html) def _check_hash(self, userid, userhmac): """ @@ -223,4 +210,4 @@ class ConsentResource(DirectServeResource): ) if not compare_digest(want_mac, userhmac): - raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect") + raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") diff --git a/synapse/rest/health.py b/synapse/rest/health.py new file mode 100644 index 0000000000..0170950bf3 --- /dev/null +++ b/synapse/rest/health.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.resource import Resource + + +class HealthResource(Resource): + """A resource that does nothing except return a 200 with a body of `OK`, + which can be used as a health check. + + Note: `SynapseRequest._should_log_request` ensures that requests to + `/health` do not get logged at INFO. + """ + + isLeaf = 1 + + def render_GET(self, request): + request.setHeader(b"Content-Type", b"text/plain") + return b"OK" diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index ab671f7334..5db7f81c2d 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,23 +15,19 @@ import logging from typing import Dict, Set -from canonicaljson import encode_canonical_json, json from signedjson.sign import sign_json from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import ( - DirectServeResource, - respond_with_json_bytes, - wrap_json_request_handler, -) +from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request +from synapse.util import json_decoder logger = logging.getLogger(__name__) -class RemoteKey(DirectServeResource): - """HTTP resource for retreiving the TLS certificate and NACL signature +class RemoteKey(DirectServeJsonResource): + """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks that the NACL signature for the remote server is valid. Returns a dict of @@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource): isLeaf = True def __init__(self, hs): + super().__init__() + self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.config = hs.config - @wrap_json_request_handler async def _async_render_GET(self, request): if len(request.postpath) == 1: (server,) = request.postpath @@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource): await self.query_keys(request, query, query_remote_on_cache_miss=True) - @wrap_json_request_handler async def _async_render_POST(self, request): content = parse_json_object_from_request(request) @@ -206,18 +202,22 @@ class RemoteKey(DirectServeResource): if miss: cache_misses.setdefault(server_name, set()).add(key_id) + # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) else: for ts_added, result in results: + # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(result["key_json"])) + # If there is a cache miss, request the missing keys, then recurse (and + # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: await self.fetcher.get_keys(cache_misses) await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json in json_results: - key_json = json.loads(key_json) + key_json = json_decoder.decode(key_json.decode("utf-8")) for signing_key in self.config.key_server_signing_keys: key_json = sign_json(key_json, self.config.server_name, signing_key) @@ -225,4 +225,4 @@ class RemoteKey(DirectServeResource): results = {"server_keys": signed_keys} - respond_with_json_bytes(request, 200, encode_canonical_json(results)) + respond_with_json(request, 200, results, canonical_json=True) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 3689777266..6568e61829 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -16,10 +16,10 @@ import logging import os +import urllib +from typing import Awaitable -from six.moves import urllib - -from twisted.internet import defer +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error @@ -78,8 +78,9 @@ def respond_404(request): ) -@defer.inlineCallbacks -def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None): +async def respond_with_file( + request, media_type, file_path, file_size=None, upload_name=None +): logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): @@ -90,7 +91,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) finish_request(request) else: @@ -199,8 +200,9 @@ def _can_encode_filename_as_token(x): return True -@defer.inlineCallbacks -def respond_with_responder(request, responder, media_type, file_size, upload_name=None): +async def respond_with_responder( + request, responder, media_type, file_size, upload_name=None +): """Responds to the request with given responder. If responder is None then returns 404. @@ -219,7 +221,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam add_file_headers(request, media_type, file_size, upload_name) try: with responder: - yield responder.write_to_consumer(request) + await responder.write_to_consumer(request) except Exception as e: # The majority of the time this will be due to the client having gone # away. Unfortunately, Twisted simply throws a generic exception at us @@ -233,21 +235,21 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam finish_request(request) -class Responder(object): +class Responder: """Represents a response that can be streamed to the requester. Responder is a context manager which *must* be used, so that any resources held can be cleaned up. """ - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: """Stream response into consumer Args: - consumer (IConsumer) + consumer: The consumer to stream into. Returns: - Deferred: Resolves once the response has finished being written + Resolves once the response has finished being written """ pass @@ -258,7 +260,7 @@ class Responder(object): pass -class FileInfo(object): +class FileInfo: """Details about a requested/uploaded file. Attributes: diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 9f747de263..68dd2a1c8a 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -14,16 +14,10 @@ # limitations under the License. # -from twisted.web.server import NOT_DONE_YET +from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.server import ( - DirectServeResource, - respond_with_json, - wrap_json_request_handler, -) - -class MediaConfigResource(DirectServeResource): +class MediaConfigResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs): @@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource): self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - @wrap_json_request_handler async def _async_render_GET(self, request): await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - def render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request): respond_with_json(request, 200, {}, send_cors=True) - return NOT_DONE_YET diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 24d3ae5bbc..d3d8457303 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -15,18 +15,14 @@ import logging import synapse.http.servlet -from synapse.http.server import ( - DirectServeResource, - set_cors_headers, - wrap_json_request_handler, -) +from synapse.http.server import DirectServeJsonResource, set_cors_headers from ._base import parse_media_id, respond_404 logger = logging.getLogger(__name__) -class DownloadResource(DirectServeResource): +class DownloadResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs, media_repo): @@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource): self.media_repo = media_repo self.server_name = hs.hostname - # this is expected by @wrap_json_request_handler - self.clock = hs.get_clock() - - @wrap_json_request_handler async def _async_render_GET(self, request): set_cors_headers(request) request.setHeader( diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index e25c382c9c..d2826374a7 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -33,7 +33,7 @@ def _wrap_in_base_path(func): return _wrapped -class MediaFilePaths(object): +class MediaFilePaths: """Describes where files are stored on disk. Most of the functions have a `*_rel` variant which returns a file path that diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index fd10d42f2f..9a1b7779f7 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -18,12 +18,11 @@ import errno import logging import os import shutil -from typing import Dict, Tuple - -from six import iteritems +from typing import IO, Dict, Optional, Tuple import twisted.internet.error import twisted.web.http +from twisted.web.http import Request from twisted.web.resource import Resource from synapse.api.errors import ( @@ -42,6 +41,7 @@ from synapse.util.stringutils import random_string from ._base import ( FileInfo, + Responder, get_filename_from_headers, respond_404, respond_with_responder, @@ -62,7 +62,7 @@ logger = logging.getLogger(__name__) UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 -class MediaRepository(object): +class MediaRepository: def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() @@ -137,19 +137,24 @@ class MediaRepository(object): self.recently_accessed_locals.add(media_id) async def create_content( - self, media_type, upload_name, content, content_length, auth_user - ): + self, + media_type: str, + upload_name: str, + content: IO, + content_length: int, + auth_user: str, + ) -> str: """Store uploaded content for a local user and return the mxc URL Args: - media_type(str): The content type of the file - upload_name(str): The name of the file + media_type: The content type of the file + upload_name: The name of the file content: A file like object that is the content to store - content_length(int): The length of the content - auth_user(str): The user_id of the uploader + content_length: The length of the content + auth_user: The user_id of the uploader Returns: - Deferred[str]: The mxc url of the stored content + The mxc url of the stored content """ media_id = random_string(24) @@ -172,19 +177,20 @@ class MediaRepository(object): return "mxc://%s/%s" % (self.server_name, media_id) - async def get_local_media(self, request, media_id, name): + async def get_local_media( + self, request: Request, media_id: str, name: Optional[str] + ) -> None: """Responds to reqests for local media, if exists, or returns 404. Args: - request(twisted.web.http.Request) - media_id (str): The media ID of the content. (This is the same as + request: The incoming request. + media_id: The media ID of the content. (This is the same as the file_id for local content.) - name (str|None): Optional name that, if specified, will be used as + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: @@ -205,20 +211,20 @@ class MediaRepository(object): request, responder, media_type, media_length, upload_name ) - async def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media( + self, request: Request, server_name: str, media_id: str, name: Optional[str] + ) -> None: """Respond to requests for remote media. Args: - request(twisted.web.http.Request) - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). - name (str|None): Optional name that, if specified, will be used as + request: The incoming request. + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ if ( self.federation_domain_whitelist is not None @@ -247,17 +253,16 @@ class MediaRepository(object): else: respond_404(request) - async def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: """Gets the media info associated with the remote file, downloading if necessary. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: - Deferred[dict]: The media_info of the file + The media info of the file """ if ( self.federation_domain_whitelist is not None @@ -280,7 +285,9 @@ class MediaRepository(object): return media_info - async def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl( + self, server_name: str, media_id: str + ) -> Tuple[Optional[Responder], dict]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -290,7 +297,7 @@ class MediaRepository(object): remote server). Returns: - Deferred[(Responder, media_info)] + A tuple of responder and the media info of the file. """ media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -321,26 +328,28 @@ class MediaRepository(object): responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file( + self, server_name: str, media_id: str, file_id: str + ) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. Args: - server_name (str): Originating server - media_id (str): The media ID of the content (as defined by the + server_name: Originating server + media_id: The media ID of the content (as defined by the remote server). This is different than the file_id, which is locally generated. - file_id (str): Local file ID + file_id: Local file ID Returns: - Deferred[MediaInfo] + The media info of the file. """ file_info = FileInfo(server_name=server_name, file_id=file_id) with self.media_storage.store_into_file(file_info) as (f, fname, finish): request_path = "/".join( - ("/_matrix/media/v1/download", server_name, media_id) + ("/_matrix/media/r0/download", server_name, media_id) ) try: length, headers = await self.client.get_file( @@ -551,25 +560,31 @@ class MediaRepository(object): return output_path async def _generate_thumbnails( - self, server_name, media_id, file_id, media_type, url_cache=False - ): + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: """Generate and store thumbnails for an image. Args: - server_name (str|None): The server name if remote media, else None if local - media_id (str): The media ID of the content. (This is the same as + server_name: The server name if remote media, else None if local + media_id: The media ID of the content. (This is the same as the file_id for local content) - file_id (str): Local file ID - media_type (str): The content type of the file - url_cache (bool): If we are thumbnailing images downloaded for the URL cache, + file_id: Local file ID + media_type: The content type of the file + url_cache: If we are thumbnailing images downloaded for the URL cache, used exclusively by the url previewer Returns: - Deferred[dict]: Dict with "width" and "height" keys of original image + Dict with "width" and "height" keys of original image or None if the + media cannot be thumbnailed. """ requirements = self._get_thumbnail_requirements(media_type) if not requirements: - return + return None input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) @@ -586,7 +601,7 @@ class MediaRepository(object): m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = await defer_to_thread( @@ -606,7 +621,7 @@ class MediaRepository(object): thumbnails[(t_width, t_height, r_type)] = r_method # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in iteritems(thumbnails): + for (t_width, t_height, t_type), t_method in thumbnails.items(): # Generate the thumbnail if t_method == "crop": t_byte_source = await defer_to_thread( @@ -705,7 +720,7 @@ class MediaRepositoryResource(Resource): Uploads are POSTed to a resource which returns a token which is used to GET the download:: - => POST /_matrix/media/v1/upload HTTP/1.1 + => POST /_matrix/media/r0/upload HTTP/1.1 Content-Type: <media-type> Content-Length: <content-length> @@ -716,7 +731,7 @@ class MediaRepositoryResource(Resource): { "content_uri": "mxc://<server-name>/<media-id>" } - => GET /_matrix/media/v1/download/<server-name>/<media-id> HTTP/1.1 + => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1 <= HTTP/1.1 200 OK Content-Type: <media-type> @@ -727,7 +742,7 @@ class MediaRepositoryResource(Resource): Clients can get thumbnails by supplying a desired width and height and thumbnailing method:: - => GET /_matrix/media/v1/thumbnail/<server_name> + => GET /_matrix/media/r0/thumbnail/<server_name> /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1 <= HTTP/1.1 200 OK diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 683a79c966..3a352b5631 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -12,73 +12,79 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import contextlib import logging import os import shutil -import sys - -import six +from typing import IO, TYPE_CHECKING, Any, Optional, Sequence -from twisted.internet import defer from twisted.protocols.basic import FileSender from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer -from ._base import Responder +from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.server import HomeServer + + from .storage_provider import StorageProviderWrapper logger = logging.getLogger(__name__) -class MediaStorage(object): +class MediaStorage: """Responsible for storing/fetching files from local sources. Args: - hs (synapse.server.Homeserver) - local_media_directory (str): Base path where we store media on disk - filepaths (MediaFilePaths) - storage_providers ([StorageProvider]): List of StorageProvider that are - used to fetch and store files. + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. """ - def __init__(self, hs, local_media_directory, filepaths, storage_providers): + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProviderWrapper"], + ): self.hs = hs self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers - @defer.inlineCallbacks - def store_file(self, source, file_info): + async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers Args: source: A file like object that should be written - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Returns: - Deferred[str]: the file path written to in the primary media store + the file path written to in the primary media store """ with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - yield defer_to_thread( + await defer_to_thread( self.hs.get_reactor(), _write_file_synchronously, source, f ) - yield finish_cb() + await finish_cb() return fname @contextlib.contextmanager - def store_into_file(self, file_info): + def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as described by file_info. Actually yields a 3-tuple (file, fname, finish_cb), where file is a file like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns a Deferred. + on disk, and finish_cb is a function that returns an awaitable. fname can be used to read the contents from after upload, e.g. to generate thumbnails. @@ -88,13 +94,13 @@ class MediaStorage(object): error. Args: - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Example: with media_storage.store_into_file(info) as (f, fname, finish_cb): # .. write into f ... - yield finish_cb() + await finish_cb() """ path = self._file_info_to_path(file_info) @@ -106,10 +112,9 @@ class MediaStorage(object): finished_called = [False] - @defer.inlineCallbacks - def finish(): + async def finish(): for provider in self.storage_providers: - yield provider.store_file(path, file_info) + await provider.store_file(path, file_info) finished_called[0] = True @@ -117,27 +122,24 @@ class MediaStorage(object): with open(fname, "wb") as f: yield f, fname, finish except Exception: - t, v, tb = sys.exc_info() try: os.remove(fname) except Exception: pass - six.reraise(t, v, tb) + raise if not finished_called: raise Exception("Finished callback not called") - @defer.inlineCallbacks - def fetch_media(self, file_info): + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache and configured storage providers. Args: - file_info (FileInfo) + file_info Returns: - Deferred[Responder|None]: Returns a Responder if the file was found, - otherwise None. + Returns a Responder if the file was found, otherwise None. """ path = self._file_info_to_path(file_info) @@ -146,23 +148,22 @@ class MediaStorage(object): return FileResponder(open(local_path, "rb")) for provider in self.storage_providers: - res = yield provider.fetch(path, file_info) + res = await provider.fetch(path, file_info) # type: Any if res: logger.debug("Streaming %s from %s", path, provider) return res return None - @defer.inlineCallbacks - def ensure_media_is_in_local_cache(self, file_info): + async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: """Ensures that the given file is in the local cache. Attempts to download it from storage providers if it isn't. Args: - file_info (FileInfo) + file_info Returns: - Deferred[str]: Full path to local file + Full path to local file """ path = self._file_info_to_path(file_info) local_path = os.path.join(self.local_media_directory, path) @@ -174,29 +175,23 @@ class MediaStorage(object): os.makedirs(dirname) for provider in self.storage_providers: - res = yield provider.fetch(path, file_info) + res = await provider.fetch(path, file_info) # type: Any if res: with res: consumer = BackgroundFileConsumer( open(local_path, "wb"), self.hs.get_reactor() ) - yield res.write_to_consumer(consumer) - yield consumer.wait() + await res.write_to_consumer(consumer) + await consumer.wait() return local_path raise Exception("file could not be found") - def _file_info_to_path(self, file_info): + def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. The path is suitable for storing files under a directory, e.g. used to store files on local FS under the base media repository directory. - - Args: - file_info (FileInfo) - - Returns: - str """ if file_info.url_cache: if file_info.thumbnail: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index f206605727..cd8c246594 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -24,28 +24,24 @@ import shutil import sys import traceback from typing import Dict, Optional +from urllib import parse as urlparse -import six -from six import string_types -from six.moves import urllib_parse as urlparse +import attr -from canonicaljson import json - -from twisted.internet import defer from twisted.internet.error import DNSLookupError from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient from synapse.http.server import ( - DirectServeResource, + DirectServeJsonResource, respond_with_json, respond_with_json_bytes, - wrap_json_request_handler, ) from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.stringutils import random_string @@ -60,8 +56,67 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) OG_TAG_NAME_MAXLEN = 50 OG_TAG_VALUE_MAXLEN = 1000 +ONE_HOUR = 60 * 60 * 1000 + +# A map of globs to API endpoints. +_oembed_globs = { + # Twitter. + "https://publish.twitter.com/oembed": [ + "https://twitter.com/*/status/*", + "https://*.twitter.com/*/status/*", + "https://twitter.com/*/moments/*", + "https://*.twitter.com/*/moments/*", + # Include the HTTP versions too. + "http://twitter.com/*/status/*", + "http://*.twitter.com/*/status/*", + "http://twitter.com/*/moments/*", + "http://*.twitter.com/*/moments/*", + ], +} +# Convert the globs to regular expressions. +_oembed_patterns = {} +for endpoint, globs in _oembed_globs.items(): + for glob in globs: + # Convert the glob into a sane regular expression to match against. The + # rules followed will be slightly different for the domain portion vs. + # the rest. + # + # 1. The scheme must be one of HTTP / HTTPS (and have no globs). + # 2. The domain can have globs, but we limit it to characters that can + # reasonably be a domain part. + # TODO: This does not attempt to handle Unicode domain names. + # 3. Other parts allow a glob to be any one, or more, characters. + results = urlparse.urlparse(glob) + + # Ensure the scheme does not have wildcards (and is a sane scheme). + if results.scheme not in {"http", "https"}: + raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,)) + + pattern = urlparse.urlunparse( + [ + results.scheme, + re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"), + ] + + [re.escape(part).replace("\\*", ".+") for part in results[2:]] + ) + _oembed_patterns[re.compile(pattern)] = endpoint + + +@attr.s +class OEmbedResult: + # Either HTML content or URL must be provided. + html = attr.ib(type=Optional[str]) + url = attr.ib(type=Optional[str]) + title = attr.ib(type=Optional[str]) + # Number of seconds to cache the content. + cache_age = attr.ib(type=int) -class PreviewUrlResource(DirectServeResource): + +class OEmbedError(Exception): + """An error occurred processing the oEmbed object.""" + + +class PreviewUrlResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs, media_repo, media_storage): @@ -85,6 +140,15 @@ class PreviewUrlResource(DirectServeResource): self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage + # We run the background jobs if we're the instance specified (or no + # instance is specified, where we assume there is only one instance + # serving media). + instance_running_jobs = hs.config.media.media_instance_running_background_jobs + self._worker_run_media_background_jobs = ( + instance_running_jobs is None + or instance_running_jobs == hs.get_instance_name() + ) + self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist self.url_preview_accept_language = hs.config.url_preview_accept_language @@ -94,18 +158,18 @@ class PreviewUrlResource(DirectServeResource): cache_name="url_previews", clock=self.clock, # don't spider URLs more often than once an hour - expiry_ms=60 * 60 * 1000, + expiry_ms=ONE_HOUR, ) - self._cleaner_loop = self.clock.looping_call( - self._start_expire_url_cache_data, 10 * 1000 - ) + if self._worker_run_media_background_jobs: + self._cleaner_loop = self.clock.looping_call( + self._start_expire_url_cache_data, 10 * 1000 + ) - def render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request): request.setHeader(b"Allow", b"OPTIONS, GET") - return respond_with_json(request, 200, {}, send_cors=True) + respond_with_json(request, 200, {}, send_cors=True) - @wrap_json_request_handler async def _async_render_GET(self, request): # XXX: if get_user_by_req fails, what should we do in an async render? @@ -163,19 +227,19 @@ class PreviewUrlResource(DirectServeResource): else: logger.info("Returning cached response") - og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) + og = await make_deferred_yieldable(observable.observe()) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url, user, ts): + async def _do_preview(self, url: str, user: str, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: - url (str): - user (str): - ts (int): + url: The URL to preview. + user: The user requesting the preview. + ts: The timestamp requested for the preview. Returns: - Deferred[bytes]: json-encoded og data + json-encoded og data """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) @@ -188,7 +252,7 @@ class PreviewUrlResource(DirectServeResource): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. og = cache_result["og"] - if isinstance(og, six.text_type): + if isinstance(og, str): og = og.encode("utf8") return og @@ -290,7 +354,7 @@ class PreviewUrlResource(DirectServeResource): logger.debug("Calculated OG for %s as %s", url, og) - jsonog = json.dumps(og) + jsonog = json_encoder.encode(og) # store OG in history-aware DB cache await self.store.store_url_cache( @@ -305,6 +369,87 @@ class PreviewUrlResource(DirectServeResource): return jsonog.encode("utf8") + def _get_oembed_url(self, url: str) -> Optional[str]: + """ + Check whether the URL should be downloaded as oEmbed content instead. + + Params: + url: The URL to check. + + Returns: + A URL to use instead or None if the original URL should be used. + """ + for url_pattern, endpoint in _oembed_patterns.items(): + if url_pattern.fullmatch(url): + return endpoint + + # No match. + return None + + async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult: + """ + Request content from an oEmbed endpoint. + + Params: + endpoint: The oEmbed API endpoint. + url: The URL to pass to the API. + + Returns: + An object representing the metadata returned. + + Raises: + OEmbedError if fetching or parsing of the oEmbed information fails. + """ + try: + logger.debug("Trying to get oEmbed content for url '%s'", url) + result = await self.client.get_json( + endpoint, + # TODO Specify max height / width. + # Note that only the JSON format is supported. + args={"url": url}, + ) + + # Ensure there's a version of 1.0. + if result.get("version") != "1.0": + raise OEmbedError("Invalid version: %s" % (result.get("version"),)) + + oembed_type = result.get("type") + + # Ensure the cache age is None or an int. + cache_age = result.get("cache_age") + if cache_age: + cache_age = int(cache_age) + + oembed_result = OEmbedResult(None, None, result.get("title"), cache_age) + + # HTML content. + if oembed_type == "rich": + oembed_result.html = result.get("html") + return oembed_result + + if oembed_type == "photo": + oembed_result.url = result.get("url") + return oembed_result + + # TODO Handle link and video types. + + if "thumbnail_url" in result: + oembed_result.url = result.get("thumbnail_url") + return oembed_result + + raise OEmbedError("Incompatible oEmbed information.") + + except OEmbedError as e: + # Trap OEmbedErrors first so we can directly re-raise them. + logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) + raise + + except Exception as e: + # Trap any exception and let the code follow as usual. + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) + raise OEmbedError() from e + async def _download_url(self, url, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a @@ -314,54 +459,90 @@ class PreviewUrlResource(DirectServeResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) - with self.media_storage.store_into_file(file_info) as (f, fname, finish): + # If this URL can be accessed via oEmbed, use that instead. + url_to_download = url + oembed_url = self._get_oembed_url(url) + if oembed_url: + # The result might be a new URL to download, or it might be HTML content. try: - logger.debug("Trying to get preview for url '%s'", url) - length, headers, uri, code = await self.client.get_file( - url, - output_stream=f, - max_size=self.max_spider_size, - headers={"Accept-Language": self.url_preview_accept_language}, - ) - except SynapseError: - # Pass SynapseErrors through directly, so that the servlet - # handler will return a SynapseError to the client instead of - # blank data or a 500. - raise - except DNSLookupError: - # DNS lookup returned no results - # Note: This will also be the case if one of the resolved IP - # addresses is blacklisted - raise SynapseError( - 502, - "DNS resolution failure during URL preview generation", - Codes.UNKNOWN, - ) - except Exception as e: - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading %s: %r", url, e) + oembed_result = await self._get_oembed_content(oembed_url, url) + if oembed_result.url: + url_to_download = oembed_result.url + elif oembed_result.html: + url_to_download = None + except OEmbedError: + # If an error occurs, try doing a normal preview. + pass - raise SynapseError( - 500, - "Failed to download content: %s" - % (traceback.format_exception_only(sys.exc_info()[0], e),), - Codes.UNKNOWN, - ) - await finish() + if url_to_download: + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + try: + logger.debug("Trying to get preview for url '%s'", url_to_download) + length, headers, uri, code = await self.client.get_file( + url_to_download, + output_stream=f, + max_size=self.max_spider_size, + headers={"Accept-Language": self.url_preview_accept_language}, + ) + except SynapseError: + # Pass SynapseErrors through directly, so that the servlet + # handler will return a SynapseError to the client instead of + # blank data or a 500. + raise + except DNSLookupError: + # DNS lookup returned no results + # Note: This will also be the case if one of the resolved IP + # addresses is blacklisted + raise SynapseError( + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, + ) + except Exception as e: + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading %s: %r", url_to_download, e) + + raise SynapseError( + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), + Codes.UNKNOWN, + ) + await finish() + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + + download_name = get_filename_from_headers(headers) + + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + expires = ONE_HOUR + etag = headers["ETag"][0] if "ETag" in headers else None + else: + html_bytes = oembed_result.html.encode("utf-8") # type: ignore + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + f.write(html_bytes) + await finish() + + media_type = "text/html" + download_name = oembed_result.title + length = len(html_bytes) + # If a specific cache age was not given, assume 1 hour. + expires = oembed_result.cache_age or ONE_HOUR + uri = oembed_url + code = 200 + etag = None try: - if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode("ascii") - else: - media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() - download_name = get_filename_from_headers(headers) - await self.store.store_local_media( media_id=file_id, media_type=media_type, - time_now_ms=self.clock.time_msec(), + time_now_ms=time_now_ms, upload_name=download_name, media_length=length, user_id=user, @@ -384,10 +565,8 @@ class PreviewUrlResource(DirectServeResource): "filename": fname, "uri": uri, "response_code": code, - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - "expires": 60 * 60 * 1000, - "etag": headers["ETag"][0] if "ETag" in headers else None, + "expires": expires, + "etag": etag, } def _start_expire_url_cache_data(self): @@ -400,11 +579,13 @@ class PreviewUrlResource(DirectServeResource): """ # TODO: Delete from backup media store + assert self._worker_run_media_background_jobs + now = self.clock.time_msec() logger.debug("Running url preview cache expiry") - if not (await self.store.db.updates.has_completed_background_updates()): + if not (await self.store.db_pool.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return @@ -442,7 +623,7 @@ class PreviewUrlResource(DirectServeResource): # These may be cached for a bit on the client (i.e., they # may have a room open with a preview url thing open). # So we wait a couple of days before deleting, just in case. - expire_before = now - 2 * 24 * 60 * 60 * 1000 + expire_before = now - 2 * 24 * ONE_HOUR media_ids = await self.store.get_url_cache_media_before(expire_before) removed_media = [] @@ -631,7 +812,7 @@ def _iterate_over_text(tree, *tags_to_ignore): if el is None: return - if isinstance(el, string_types): + if isinstance(el, str): yield el elif el.tag not in tags_to_ignore: # el.text is the text before the first child, so we can immediately diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 858680be26..18c9ed48d6 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -13,65 +13,66 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging import os import shutil - -from twisted.internet import defer +from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from ._base import FileInfo, Responder from .media_storage import FileResponder logger = logging.getLogger(__name__) -class StorageProvider(object): +class StorageProvider: """A storage provider is a service that can store uploaded media and retrieve them. """ - def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo): """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) - - Returns: - Deferred + path: Relative path of file in local cache + file_info: The metadata of the file. """ - pass - def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) + path: Relative path of file in local cache + file_info: The metadata of the file. Returns: - Deferred(Responder): Returns a Responder if the provider has the file, - otherwise returns None. + Returns a Responder if the provider has the file, otherwise returns None. """ - pass class StorageProviderWrapper(StorageProvider): """Wraps a storage provider and provides various config options Args: - backend (StorageProvider) - store_local (bool): Whether to store new local files or not. - store_synchronous (bool): Whether to wait for file to be successfully + backend: The storage provider to wrap. + store_local: Whether to store new local files or not. + store_synchronous: Whether to wait for file to be successfully uploaded, or todo the upload in the background. - store_remote (bool): Whether remote media should be uploaded + store_remote: Whether remote media should be uploaded """ - def __init__(self, backend, store_local, store_synchronous, store_remote): + def __init__( + self, + backend: StorageProvider, + store_local: bool, + store_synchronous: bool, + store_remote: bool, + ): self.backend = backend self.store_local = store_local self.store_synchronous = store_synchronous @@ -80,28 +81,38 @@ class StorageProviderWrapper(StorageProvider): def __str__(self): return "StorageProviderWrapper[%s]" % (self.backend,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): if not file_info.server_name and not self.store_local: - return defer.succeed(None) + return None if file_info.server_name and not self.store_remote: - return defer.succeed(None) + return None if self.store_synchronous: - return self.backend.store_file(path, file_info) + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + result = self.backend.store_file(path, file_info) + if inspect.isawaitable(result): + return await result else: # TODO: Handle errors. - def store(): + async def store(): try: - return self.backend.store_file(path, file_info) + result = self.backend.store_file(path, file_info) + if inspect.isawaitable(result): + return await result except Exception: logger.exception("Error storing file") run_in_background(store) - return defer.succeed(None) + return None - def fetch(self, path, file_info): - return self.backend.fetch(path, file_info) + async def fetch(self, path, file_info): + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + result = self.backend.fetch(path, file_info) + if inspect.isawaitable(result): + return await result class FileStorageProviderBackend(StorageProvider): @@ -120,7 +131,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -130,11 +141,11 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return defer_to_thread( + return await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - def fetch(self, path, file_info): + async def fetch(self, path, file_info): """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 0b87220234..a83535b97b 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,11 +16,7 @@ import logging -from synapse.http.server import ( - DirectServeResource, - set_cors_headers, - wrap_json_request_handler, -) +from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string from ._base import ( @@ -34,7 +30,7 @@ from ._base import ( logger = logging.getLogger(__name__) -class ThumbnailResource(DirectServeResource): +class ThumbnailResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs, media_repo, media_storage): @@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource): self.media_storage = media_storage self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - self.clock = hs.get_clock() - @wrap_json_request_handler async def _async_render_GET(self, request): set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index c234ea7421..d681bf7bf0 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -12,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from io import BytesIO -import PIL.Image as Image +from PIL import Image as Image logger = logging.getLogger(__name__) @@ -32,7 +31,7 @@ EXIF_TRANSPOSE_MAPPINGS = { } -class Thumbnailer(object): +class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 83d005812d..3ebf7a68e6 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -15,20 +15,14 @@ import logging -from twisted.web.server import NOT_DONE_YET - from synapse.api.errors import Codes, SynapseError -from synapse.http.server import ( - DirectServeResource, - respond_with_json, - wrap_json_request_handler, -) +from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string logger = logging.getLogger(__name__) -class UploadResource(DirectServeResource): +class UploadResource(DirectServeJsonResource): isLeaf = True def __init__(self, hs, media_repo): @@ -43,11 +37,9 @@ class UploadResource(DirectServeResource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - def render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request): respond_with_json(request, 200, {}, send_cors=True) - return NOT_DONE_YET - @wrap_json_request_handler async def _async_render_POST(self, request): requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py index c03194f001..f7a0bc4bdb 100644 --- a/synapse/rest/oidc/callback_resource.py +++ b/synapse/rest/oidc/callback_resource.py @@ -14,18 +14,17 @@ # limitations under the License. import logging -from synapse.http.server import DirectServeResource, wrap_html_request_handler +from synapse.http.server import DirectServeHtmlResource logger = logging.getLogger(__name__) -class OIDCCallbackResource(DirectServeResource): +class OIDCCallbackResource(DirectServeHtmlResource): isLeaf = 1 def __init__(self, hs): super().__init__() self._oidc_handler = hs.get_oidc_handler() - @wrap_html_request_handler async def _async_render_GET(self, request): - return await self._oidc_handler.handle_oidc_callback(request) + await self._oidc_handler.handle_oidc_callback(request) diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 75e58043b4..c10188a5d7 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -16,10 +16,10 @@ from twisted.python import failure from synapse.api.errors import SynapseError -from synapse.http.server import DirectServeResource, return_html_error +from synapse.http.server import DirectServeHtmlResource, return_html_error -class SAML2ResponseResource(DirectServeResource): +class SAML2ResponseResource(DirectServeHtmlResource): """A Twisted web resource which handles the SAML response""" isLeaf = 1 diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 20177b44e7..f591cc6c5c 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -13,17 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging from twisted.web.resource import Resource from synapse.http.server import set_cors_headers +from synapse.util import json_encoder logger = logging.getLogger(__name__) -class WellKnownBuilder(object): +class WellKnownBuilder: """Utility to construct the well-known response Args: @@ -67,4 +67,4 @@ class WellKnownResource(Resource): logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") - return json.dumps(r).encode("utf-8") + return json_encoder.encode(r).encode("utf-8") diff --git a/synapse/secrets.py b/synapse/secrets.py index 0b327a0f82..fb6d90a3b7 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -19,22 +19,25 @@ Injectable secrets module for Synapse. See https://docs.python.org/3/library/secrets.html#module-secrets for the API used in Python 3.6, and the API emulated in Python 2.7. """ - import sys # secrets is available since python 3.6 if sys.version_info[0:2] >= (3, 6): import secrets - def Secrets(): - return secrets + class Secrets: + def token_bytes(self, nbytes=32): + return secrets.token_bytes(nbytes) + + def token_hex(self, nbytes=32): + return secrets.token_hex(nbytes) else: - import os import binascii + import os - class Secrets(object): + class Secrets: def token_bytes(self, nbytes=32): return os.urandom(nbytes) diff --git a/synapse/server.py b/synapse/server.py index fe94836a2c..9055b97ac3 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -22,10 +22,14 @@ # Imports required for the default HomeServer() implementation import abc +import functools import logging import os +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast +import twisted from twisted.mail.smtp import sendmail +from twisted.web.iweb import IPolicyForHTTPS from synapse.api.auth import Auth from synapse.api.filtering import Filtering @@ -44,7 +48,6 @@ from synapse.federation.federation_client import FederationClient from synapse.federation.federation_server import ( FederationHandlerRegistry, FederationServer, - ReplicationFederationHandlerRegistry, ) from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.sender import FederationSender @@ -73,14 +76,18 @@ from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.register import RegistrationHandler -from synapse.handlers.room import RoomContextHandler, RoomCreationHandler +from synapse.handlers.room import ( + RoomContextHandler, + RoomCreationHandler, + RoomShutdownHandler, +) from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler -from synapse.handlers.typing import TypingHandler +from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient @@ -90,7 +97,7 @@ from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer -from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -102,32 +109,74 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import DataStores, Storage +from synapse.storage import Databases, DataStore, Storage from synapse.streams.events import EventSources +from synapse.types import DomainSpecificString from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from synapse.handlers.oidc_handler import OidcHandler + from synapse.handlers.saml_handler import SamlHandler + + +T = TypeVar("T", bound=Callable[..., Any]) + + +def cache_in_self(builder: T) -> T: + """Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and + returning if so. If not, calls the given function and sets `self.foo` to it. + + Also ensures that dependency cycles throw an exception correctly, rather + than overflowing the stack. + """ + + if not builder.__name__.startswith("get_"): + raise Exception( + "@cache_in_self can only be used on functions starting with `get_`" + ) + + depname = builder.__name__[len("get_") :] + + building = [False] + + @functools.wraps(builder) + def _get(self): + try: + return getattr(self, depname) + except AttributeError: + pass + + # Prevent cyclic dependencies from deadlocking + if building[0]: + raise ValueError("Cyclic dependency while building %s" % (depname,)) -class HomeServer(object): + building[0] = True + try: + dep = builder(self) + setattr(self, depname, dep) + finally: + building[0] = False + + return dep + + # We cast here as we need to tell mypy that `_get` has the same signature as + # `builder`. + return cast(T, _get) + + +class HomeServer(metaclass=abc.ABCMeta): """A basic homeserver object without lazy component builders. This will need all of the components it requires to either be passed as constructor arguments, or the relevant methods overriding to create them. Typically this would only be used for unit tests. - For every dependency in the DEPENDENCIES list below, this class creates one - method, - def get_DEPENDENCY(self) - which returns the value of that dependency. If no value has yet been set - nor was provided to the constructor, it will attempt to call a lazy builder - method called - def build_DEPENDENCY(self) - which must be implemented by the subclass. This code may call any of the - required "get" methods on the instance to obtain the sub-dependencies that - one requires. + Dependencies should be added by creating a `def get_<depname>(self)` + function, wrapping it in `@cache_in_self`. Attributes: config (synapse.config.homeserver.HomeserverConfig): @@ -135,85 +184,6 @@ class HomeServer(object): we are listening on to provide HTTP services. """ - __metaclass__ = abc.ABCMeta - - DEPENDENCIES = [ - "http_client", - "federation_client", - "federation_server", - "handlers", - "auth", - "room_creation_handler", - "state_handler", - "state_resolution_handler", - "presence_handler", - "sync_handler", - "typing_handler", - "room_list_handler", - "acme_handler", - "auth_handler", - "device_handler", - "stats_handler", - "e2e_keys_handler", - "e2e_room_keys_handler", - "event_handler", - "event_stream_handler", - "initial_sync_handler", - "application_service_api", - "application_service_scheduler", - "application_service_handler", - "device_message_handler", - "profile_handler", - "event_creation_handler", - "deactivate_account_handler", - "set_password_handler", - "notifier", - "event_sources", - "keyring", - "pusherpool", - "event_builder_factory", - "filtering", - "http_client_context_factory", - "simple_http_client", - "proxied_http_client", - "media_repository", - "media_repository_resource", - "federation_transport_client", - "federation_sender", - "receipts_handler", - "macaroon_generator", - "tcp_replication", - "read_marker_handler", - "action_generator", - "user_directory_handler", - "groups_local_handler", - "groups_server_handler", - "groups_attestation_signing", - "groups_attestation_renewer", - "secrets", - "spam_checker", - "third_party_event_rules", - "room_member_handler", - "federation_registry", - "server_notices_manager", - "server_notices_sender", - "message_handler", - "pagination_handler", - "room_context_handler", - "sendmail", - "registration_handler", - "account_validity_handler", - "cas_handler", - "saml_handler", - "oidc_handler", - "event_client_serializer", - "password_policy_handler", - "storage", - "replication_streamer", - "replication_data_handler", - "replication_streams", - ] - REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] # This is overridden in derived application classes @@ -228,14 +198,17 @@ class HomeServer(object): config: The full config for the homeserver. """ if not reactor: - from twisted.internet import reactor + from twisted.internet import reactor as _reactor + + reactor = _reactor self._reactor = reactor self.hostname = hostname + # the key we use to sign events and requests + self.signing_key = config.key.signing_key[0] self.config = config - self._building = {} - self._listening_services = [] - self.start_time = None + self._listening_services = [] # type: List[twisted.internet.tcp.Port] + self.start_time = None # type: Optional[int] self._instance_id = random_string(5) self._instance_name = config.worker_name or "master" @@ -249,13 +222,13 @@ class HomeServer(object): burst_count=config.rc_registration.burst_count, ) - self.datastores = None + self.datastores = None # type: Optional[Databases] # Other kwargs are explicit dependencies for depname in kwargs: setattr(self, depname, kwargs[depname]) - def get_instance_id(self): + def get_instance_id(self) -> str: """A unique ID for this synapse process instance. This is used to distinguish running instances in worker-based @@ -271,13 +244,13 @@ class HomeServer(object): """ return self._instance_name - def setup(self): + def setup(self) -> None: logger.info("Setting up.") self.start_time = int(self.get_clock().time()) - self.datastores = DataStores(self.DATASTORE_CLASS, self) + self.datastores = Databases(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") - def setup_master(self): + def setup_master(self) -> None: """ Some handlers have side effects on instantiation (like registering background updates). This function causes them to be fetched, and @@ -286,186 +259,242 @@ class HomeServer(object): for i in self.REQUIRED_ON_MASTER_STARTUP: getattr(self, "get_" + i)() - def get_reactor(self): + def get_reactor(self) -> twisted.internet.base.ReactorBase: """ Fetch the Twisted reactor in use by this HomeServer. """ return self._reactor - def get_ip_from_request(self, request): + def get_ip_from_request(self, request) -> str: # X-Forwarded-For is handled by our custom request type. return request.getClientIP() - def is_mine(self, domain_specific_string): + def is_mine(self, domain_specific_string: DomainSpecificString) -> bool: return domain_specific_string.domain == self.hostname - def is_mine_id(self, string): + def is_mine_id(self, string: str) -> bool: return string.split(":", 1)[1] == self.hostname - def get_clock(self): + def get_clock(self) -> Clock: return self.clock - def get_datastore(self): + def get_datastore(self) -> DataStore: + if not self.datastores: + raise Exception("HomeServer.setup must be called before getting datastores") + return self.datastores.main - def get_datastores(self): + def get_datastores(self) -> Databases: + if not self.datastores: + raise Exception("HomeServer.setup must be called before getting datastores") + return self.datastores - def get_config(self): + def get_config(self) -> HomeServerConfig: return self.config - def get_distributor(self): + def get_distributor(self) -> Distributor: return self.distributor def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def build_federation_client(self): + @cache_in_self + def get_federation_client(self) -> FederationClient: return FederationClient(self) - def build_federation_server(self): + @cache_in_self + def get_federation_server(self) -> FederationServer: return FederationServer(self) - def build_handlers(self): + @cache_in_self + def get_handlers(self) -> Handlers: return Handlers(self) - def build_notifier(self): + @cache_in_self + def get_notifier(self) -> Notifier: return Notifier(self) - def build_auth(self): + @cache_in_self + def get_auth(self) -> Auth: return Auth(self) - def build_http_client_context_factory(self): + @cache_in_self + def get_http_client_context_factory(self) -> IPolicyForHTTPS: return ( InsecureInterceptableContextFactory() if self.config.use_insecure_ssl_client_just_for_testing_do_not_use else RegularPolicyForHTTPS() ) - def build_simple_http_client(self): + @cache_in_self + def get_simple_http_client(self) -> SimpleHttpClient: return SimpleHttpClient(self) - def build_proxied_http_client(self): + @cache_in_self + def get_proxied_http_client(self) -> SimpleHttpClient: return SimpleHttpClient( self, http_proxy=os.getenvb(b"http_proxy"), https_proxy=os.getenvb(b"HTTPS_PROXY"), ) - def build_room_creation_handler(self): + @cache_in_self + def get_room_creation_handler(self) -> RoomCreationHandler: return RoomCreationHandler(self) - def build_sendmail(self): + @cache_in_self + def get_room_shutdown_handler(self) -> RoomShutdownHandler: + return RoomShutdownHandler(self) + + @cache_in_self + def get_sendmail(self) -> sendmail: return sendmail - def build_state_handler(self): + @cache_in_self + def get_state_handler(self) -> StateHandler: return StateHandler(self) - def build_state_resolution_handler(self): + @cache_in_self + def get_state_resolution_handler(self) -> StateResolutionHandler: return StateResolutionHandler(self) - def build_presence_handler(self): + @cache_in_self + def get_presence_handler(self) -> PresenceHandler: return PresenceHandler(self) - def build_typing_handler(self): - return TypingHandler(self) + @cache_in_self + def get_typing_handler(self): + if self.config.worker.writers.typing == self.get_instance_name(): + return TypingWriterHandler(self) + else: + return FollowerTypingHandler(self) - def build_sync_handler(self): + @cache_in_self + def get_sync_handler(self) -> SyncHandler: return SyncHandler(self) - def build_room_list_handler(self): + @cache_in_self + def get_room_list_handler(self) -> RoomListHandler: return RoomListHandler(self) - def build_auth_handler(self): + @cache_in_self + def get_auth_handler(self) -> AuthHandler: return AuthHandler(self) - def build_macaroon_generator(self): + @cache_in_self + def get_macaroon_generator(self) -> MacaroonGenerator: return MacaroonGenerator(self) - def build_device_handler(self): + @cache_in_self + def get_device_handler(self): if self.config.worker_app: return DeviceWorkerHandler(self) else: return DeviceHandler(self) - def build_device_message_handler(self): + @cache_in_self + def get_device_message_handler(self) -> DeviceMessageHandler: return DeviceMessageHandler(self) - def build_e2e_keys_handler(self): + @cache_in_self + def get_e2e_keys_handler(self) -> E2eKeysHandler: return E2eKeysHandler(self) - def build_e2e_room_keys_handler(self): + @cache_in_self + def get_e2e_room_keys_handler(self) -> E2eRoomKeysHandler: return E2eRoomKeysHandler(self) - def build_acme_handler(self): + @cache_in_self + def get_acme_handler(self) -> AcmeHandler: return AcmeHandler(self) - def build_application_service_api(self): + @cache_in_self + def get_application_service_api(self) -> ApplicationServiceApi: return ApplicationServiceApi(self) - def build_application_service_scheduler(self): + @cache_in_self + def get_application_service_scheduler(self) -> ApplicationServiceScheduler: return ApplicationServiceScheduler(self) - def build_application_service_handler(self): + @cache_in_self + def get_application_service_handler(self) -> ApplicationServicesHandler: return ApplicationServicesHandler(self) - def build_event_handler(self): + @cache_in_self + def get_event_handler(self) -> EventHandler: return EventHandler(self) - def build_event_stream_handler(self): + @cache_in_self + def get_event_stream_handler(self) -> EventStreamHandler: return EventStreamHandler(self) - def build_initial_sync_handler(self): + @cache_in_self + def get_initial_sync_handler(self) -> InitialSyncHandler: return InitialSyncHandler(self) - def build_profile_handler(self): + @cache_in_self + def get_profile_handler(self): if self.config.worker_app: return BaseProfileHandler(self) else: return MasterProfileHandler(self) - def build_event_creation_handler(self): + @cache_in_self + def get_event_creation_handler(self) -> EventCreationHandler: return EventCreationHandler(self) - def build_deactivate_account_handler(self): + @cache_in_self + def get_deactivate_account_handler(self) -> DeactivateAccountHandler: return DeactivateAccountHandler(self) - def build_set_password_handler(self): + @cache_in_self + def get_set_password_handler(self) -> SetPasswordHandler: return SetPasswordHandler(self) - def build_event_sources(self): + @cache_in_self + def get_event_sources(self) -> EventSources: return EventSources(self) - def build_keyring(self): + @cache_in_self + def get_keyring(self) -> Keyring: return Keyring(self) - def build_event_builder_factory(self): + @cache_in_self + def get_event_builder_factory(self) -> EventBuilderFactory: return EventBuilderFactory(self) - def build_filtering(self): + @cache_in_self + def get_filtering(self) -> Filtering: return Filtering(self) - def build_pusherpool(self): + @cache_in_self + def get_pusherpool(self) -> PusherPool: return PusherPool(self) - def build_http_client(self): + @cache_in_self + def get_http_client(self) -> MatrixFederationHttpClient: tls_client_options_factory = context_factory.FederationPolicyForHTTPS( self.config ) return MatrixFederationHttpClient(self, tls_client_options_factory) - def build_media_repository_resource(self): + @cache_in_self + def get_media_repository_resource(self) -> MediaRepositoryResource: # build the media repo resource. This indirects through the HomeServer # to ensure that we only have a single instance of return MediaRepositoryResource(self) - def build_media_repository(self): + @cache_in_self + def get_media_repository(self) -> MediaRepository: return MediaRepository(self) - def build_federation_transport_client(self): + @cache_in_self + def get_federation_transport_client(self) -> TransportLayerClient: return TransportLayerClient(self) - def build_federation_sender(self): + @cache_in_self + def get_federation_sender(self): if self.should_send_federation(): return FederationSender(self) elif not self.config.worker_app: @@ -473,159 +502,152 @@ class HomeServer(object): else: raise Exception("Workers cannot send federation traffic") - def build_receipts_handler(self): + @cache_in_self + def get_receipts_handler(self) -> ReceiptsHandler: return ReceiptsHandler(self) - def build_read_marker_handler(self): + @cache_in_self + def get_read_marker_handler(self) -> ReadMarkerHandler: return ReadMarkerHandler(self) - def build_tcp_replication(self): + @cache_in_self + def get_tcp_replication(self) -> ReplicationCommandHandler: return ReplicationCommandHandler(self) - def build_action_generator(self): + @cache_in_self + def get_action_generator(self) -> ActionGenerator: return ActionGenerator(self) - def build_user_directory_handler(self): + @cache_in_self + def get_user_directory_handler(self) -> UserDirectoryHandler: return UserDirectoryHandler(self) - def build_groups_local_handler(self): + @cache_in_self + def get_groups_local_handler(self): if self.config.worker_app: return GroupsLocalWorkerHandler(self) else: return GroupsLocalHandler(self) - def build_groups_server_handler(self): + @cache_in_self + def get_groups_server_handler(self): if self.config.worker_app: return GroupsServerWorkerHandler(self) else: return GroupsServerHandler(self) - def build_groups_attestation_signing(self): + @cache_in_self + def get_groups_attestation_signing(self) -> GroupAttestationSigning: return GroupAttestationSigning(self) - def build_groups_attestation_renewer(self): + @cache_in_self + def get_groups_attestation_renewer(self) -> GroupAttestionRenewer: return GroupAttestionRenewer(self) - def build_secrets(self): + @cache_in_self + def get_secrets(self) -> Secrets: return Secrets() - def build_stats_handler(self): + @cache_in_self + def get_stats_handler(self) -> StatsHandler: return StatsHandler(self) - def build_spam_checker(self): + @cache_in_self + def get_spam_checker(self): return SpamChecker(self) - def build_third_party_event_rules(self): + @cache_in_self + def get_third_party_event_rules(self) -> ThirdPartyEventRules: return ThirdPartyEventRules(self) - def build_room_member_handler(self): + @cache_in_self + def get_room_member_handler(self): if self.config.worker_app: return RoomMemberWorkerHandler(self) return RoomMemberMasterHandler(self) - def build_federation_registry(self): - if self.config.worker_app: - return ReplicationFederationHandlerRegistry(self) - else: - return FederationHandlerRegistry() + @cache_in_self + def get_federation_registry(self) -> FederationHandlerRegistry: + return FederationHandlerRegistry(self) - def build_server_notices_manager(self): + @cache_in_self + def get_server_notices_manager(self): if self.config.worker_app: raise Exception("Workers cannot send server notices") return ServerNoticesManager(self) - def build_server_notices_sender(self): + @cache_in_self + def get_server_notices_sender(self): if self.config.worker_app: return WorkerServerNoticesSender(self) return ServerNoticesSender(self) - def build_message_handler(self): + @cache_in_self + def get_message_handler(self) -> MessageHandler: return MessageHandler(self) - def build_pagination_handler(self): + @cache_in_self + def get_pagination_handler(self) -> PaginationHandler: return PaginationHandler(self) - def build_room_context_handler(self): + @cache_in_self + def get_room_context_handler(self) -> RoomContextHandler: return RoomContextHandler(self) - def build_registration_handler(self): + @cache_in_self + def get_registration_handler(self) -> RegistrationHandler: return RegistrationHandler(self) - def build_account_validity_handler(self): + @cache_in_self + def get_account_validity_handler(self) -> AccountValidityHandler: return AccountValidityHandler(self) - def build_cas_handler(self): + @cache_in_self + def get_cas_handler(self) -> CasHandler: return CasHandler(self) - def build_saml_handler(self): + @cache_in_self + def get_saml_handler(self) -> "SamlHandler": from synapse.handlers.saml_handler import SamlHandler return SamlHandler(self) - def build_oidc_handler(self): + @cache_in_self + def get_oidc_handler(self) -> "OidcHandler": from synapse.handlers.oidc_handler import OidcHandler return OidcHandler(self) - def build_event_client_serializer(self): + @cache_in_self + def get_event_client_serializer(self) -> EventClientSerializer: return EventClientSerializer(self) - def build_password_policy_handler(self): + @cache_in_self + def get_password_policy_handler(self) -> PasswordPolicyHandler: return PasswordPolicyHandler(self) - def build_storage(self) -> Storage: - return Storage(self, self.datastores) + @cache_in_self + def get_storage(self) -> Storage: + return Storage(self, self.get_datastores()) - def build_replication_streamer(self) -> ReplicationStreamer: + @cache_in_self + def get_replication_streamer(self) -> ReplicationStreamer: return ReplicationStreamer(self) - def build_replication_data_handler(self): + @cache_in_self + def get_replication_data_handler(self) -> ReplicationDataHandler: return ReplicationDataHandler(self) - def build_replication_streams(self): + @cache_in_self + def get_replication_streams(self) -> Dict[str, Stream]: return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()} - def remove_pusher(self, app_id, push_key, user_id): - return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): + return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) - def should_send_federation(self): + def should_send_federation(self) -> bool: "Should this server be sending federation traffic directly?" return self.config.send_federation and ( not self.config.worker_app or self.config.worker_app == "synapse.app.federation_sender" ) - - -def _make_dependency_method(depname): - def _get(hs): - try: - return getattr(hs, depname) - except AttributeError: - pass - - try: - builder = getattr(hs, "build_%s" % (depname)) - except AttributeError: - raise NotImplementedError( - "%s has no %s nor a builder for it" % (type(hs).__name__, depname) - ) - - # Prevent cyclic dependencies from deadlocking - if depname in hs._building: - raise ValueError("Cyclic dependency while building %s" % (depname,)) - - hs._building[depname] = 1 - try: - dep = builder() - setattr(hs, depname, dep) - finally: - del hs._building[depname] - - return dep - - setattr(HomeServer, "get_%s" % (depname), _get) - - -# Build magic accessors for every dependency -for depname in HomeServer.DEPENDENCIES: - _make_dependency_method(depname) diff --git a/synapse/server.pyi b/synapse/server.pyi deleted file mode 100644 index fe8024d2d4..0000000000 --- a/synapse/server.pyi +++ /dev/null @@ -1,143 +0,0 @@ -from typing import Dict - -import twisted.internet - -import synapse.api.auth -import synapse.config.homeserver -import synapse.crypto.keyring -import synapse.federation.federation_server -import synapse.federation.sender -import synapse.federation.transport.client -import synapse.handlers -import synapse.handlers.auth -import synapse.handlers.deactivate_account -import synapse.handlers.device -import synapse.handlers.e2e_keys -import synapse.handlers.message -import synapse.handlers.presence -import synapse.handlers.register -import synapse.handlers.room -import synapse.handlers.room_member -import synapse.handlers.set_password -import synapse.http.client -import synapse.notifier -import synapse.push.pusherpool -import synapse.replication.tcp.client -import synapse.replication.tcp.handler -import synapse.rest.media.v1.media_repository -import synapse.server_notices.server_notices_manager -import synapse.server_notices.server_notices_sender -import synapse.state -import synapse.storage -from synapse.events.builder import EventBuilderFactory -from synapse.replication.tcp.streams import Stream - -class HomeServer(object): - @property - def config(self) -> synapse.config.homeserver.HomeServerConfig: - pass - @property - def hostname(self) -> str: - pass - def get_auth(self) -> synapse.api.auth.Auth: - pass - def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: - pass - def get_datastore(self) -> synapse.storage.DataStore: - pass - def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: - pass - def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler: - pass - def get_handlers(self) -> synapse.handlers.Handlers: - pass - def get_state_handler(self) -> synapse.state.StateHandler: - pass - def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: - pass - def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient: - """Fetch an HTTP client implementation which doesn't do any blacklisting - or support any HTTP_PROXY settings""" - pass - def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient: - """Fetch an HTTP client implementation which doesn't do any blacklisting - but does support HTTP_PROXY settings""" - pass - def get_deactivate_account_handler( - self, - ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: - pass - def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: - pass - def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: - pass - def get_event_creation_handler( - self, - ) -> synapse.handlers.message.EventCreationHandler: - pass - def get_set_password_handler( - self, - ) -> synapse.handlers.set_password.SetPasswordHandler: - pass - def get_federation_sender(self) -> synapse.federation.sender.FederationSender: - pass - def get_federation_transport_client( - self, - ) -> synapse.federation.transport.client.TransportLayerClient: - pass - def get_media_repository_resource( - self, - ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: - pass - def get_media_repository( - self, - ) -> synapse.rest.media.v1.media_repository.MediaRepository: - pass - def get_server_notices_manager( - self, - ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: - pass - def get_server_notices_sender( - self, - ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: - pass - def get_notifier(self) -> synapse.notifier.Notifier: - pass - def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler: - pass - def get_clock(self) -> synapse.util.Clock: - pass - def get_reactor(self) -> twisted.internet.base.ReactorBase: - pass - def get_keyring(self) -> synapse.crypto.keyring.Keyring: - pass - def get_tcp_replication( - self, - ) -> synapse.replication.tcp.handler.ReplicationCommandHandler: - pass - def get_replication_data_handler( - self, - ) -> synapse.replication.tcp.client.ReplicationDataHandler: - pass - def get_federation_registry( - self, - ) -> synapse.federation.federation_server.FederationHandlerRegistry: - pass - def is_mine_id(self, domain_id: str) -> bool: - pass - def get_instance_id(self) -> str: - pass - def get_instance_name(self) -> str: - pass - def get_event_builder_factory(self) -> EventBuilderFactory: - pass - def get_storage(self) -> synapse.storage.Storage: - pass - def get_registration_handler(self) -> synapse.handlers.register.RegistrationHandler: - pass - def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator: - pass - def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: - pass - def get_replication_streams(self) -> Dict[str, Stream]: - pass diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 3bf330da49..3673e7f47e 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from six import iteritems, string_types +from typing import Any from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder @@ -24,7 +23,7 @@ from synapse.types import get_localpart_from_id logger = logging.getLogger(__name__) -class ConsentServerNotices(object): +class ConsentServerNotices: """Keeps track of whether we need to send users server_notices about privacy policy consent, and sends one if we do. """ @@ -57,14 +56,11 @@ class ConsentServerNotices(object): self._consent_uri_builder = ConsentURIBuilder(hs.config) - async def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id: str) -> None: """Check if we need to send a notice to this user, and does so if so Args: - user_id (str): user to check - - Returns: - Deferred + user_id: user to check """ if self._server_notice_content is None: # not enabled @@ -107,7 +103,7 @@ class ConsentServerNotices(object): self._users_in_progress.remove(user_id) -def copy_with_str_subst(x, substitutions): +def copy_with_str_subst(x: Any, substitutions: Any) -> Any: """Deep-copy a structure, carrying out string substitions on any strings Args: @@ -118,12 +114,12 @@ def copy_with_str_subst(x, substitutions): Returns: copy of x """ - if isinstance(x, string_types): + if isinstance(x, str): return x % substitutions if isinstance(x, dict): - return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)} + return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()} if isinstance(x, (list, tuple)): - return [copy_with_str_subst(y) for y in x] + return [copy_with_str_subst(y, substitutions) for y in x] # assume it's uninterested and can be shallow-copied. return x diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 73f2cedb5c..2258d306d9 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from six import iteritems +from typing import List, Tuple from synapse.api.constants import ( EventTypes, @@ -28,7 +27,7 @@ from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG logger = logging.getLogger(__name__) -class ResourceLimitsServerNotices(object): +class ResourceLimitsServerNotices: """ Keeps track of whether the server has reached it's resource limit and ensures that the client is kept up to date. """ @@ -54,7 +53,7 @@ class ResourceLimitsServerNotices(object): and not hs.config.hs_disabled ) - async def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id: str) -> None: """Check if we need to send a notice to this user, this will be true in two cases. 1. The server has reached its limit does not reflect this @@ -62,10 +61,7 @@ class ResourceLimitsServerNotices(object): actually the server is fine Args: - user_id (str): user to check - - Returns: - Deferred + user_id: user to check """ if not self._enabled: return @@ -117,19 +113,21 @@ class ResourceLimitsServerNotices(object): elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. await self._apply_limit_block_notification( - user_id, limit_msg, limit_type + user_id, limit_msg, limit_type # type: ignore ) except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) - async def _remove_limit_block_notification(self, user_id, ref_events): + async def _remove_limit_block_notification( + self, user_id: str, ref_events: List[str] + ) -> None: """Utility method to remove limit block notifications from the server notices room. Args: - user_id (str): user to notify - ref_events (list[str]): The event_ids of pinned events that are unrelated to - limit blocking and need to be preserved. + user_id: user to notify + ref_events: The event_ids of pinned events that are unrelated to + limit blocking and need to be preserved. """ content = {"pinned": ref_events} await self._server_notices_manager.send_notice( @@ -137,16 +135,16 @@ class ResourceLimitsServerNotices(object): ) async def _apply_limit_block_notification( - self, user_id, event_body, event_limit_type - ): + self, user_id: str, event_body: str, event_limit_type: str + ) -> None: """Utility method to apply limit block notifications in the server notices room. Args: - user_id (str): user to notify - event_body(str): The human readable text that describes the block. - event_limit_type(str): Specifies the type of block e.g. monthly active user - limit has been exceeded. + user_id: user to notify + event_body: The human readable text that describes the block. + event_limit_type: Specifies the type of block e.g. monthly active user + limit has been exceeded. """ content = { "body": event_body, @@ -164,7 +162,7 @@ class ResourceLimitsServerNotices(object): user_id, content, EventTypes.Pinned, "" ) - async def _check_and_set_tags(self, user_id, room_id): + async def _check_and_set_tags(self, user_id: str, room_id: str) -> None: """ Since server notices rooms were originally not with tags, important to check that tags have been set correctly @@ -184,17 +182,16 @@ class ResourceLimitsServerNotices(object): ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) - async def _is_room_currently_blocked(self, room_id): + async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str]]: """ Determines if the room is currently blocked Args: - room_id(str): The room id of the server notices room + room_id: The room id of the server notices room Returns: - Deferred[Tuple[bool, List]]: bool: Is the room currently blocked - list: The list of pinned events that are unrelated to limit blocking + list: The list of pinned event IDs that are unrelated to limit blocking This list can be used as a convenience in the case where the block is to be lifted and the remaining pinned event references need to be preserved @@ -209,12 +206,12 @@ class ResourceLimitsServerNotices(object): # The user has yet to join the server notices room pass - referenced_events = [] + referenced_events = [] # type: List[str] if pinned_state_event is not None: referenced_events = list(pinned_state_event.content.get("pinned", [])) events = await self._store.get_events(referenced_events) - for event_id, event in iteritems(events): + for event_id, event in events.items(): if event.type != EventTypes.Message: continue if event.content.get("msgtype") == ServerNoticeMsgType: diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index bf2454c01c..0422d4c7ce 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from synapse.api.constants import EventTypes, Membership, RoomCreationPreset +from synapse.events import EventBase from synapse.types import UserID, create_requester from synapse.util.caches.descriptors import cached @@ -23,7 +25,7 @@ logger = logging.getLogger(__name__) SERVER_NOTICE_ROOM_TAG = "m.server_notice" -class ServerNoticesManager(object): +class ServerNoticesManager: def __init__(self, hs): """ @@ -50,20 +52,21 @@ class ServerNoticesManager(object): return self._config.server_notices_mxid is not None async def send_notice( - self, user_id, event_content, type=EventTypes.Message, state_key=None - ): + self, + user_id: str, + event_content: dict, + type: str = EventTypes.Message, + state_key: Optional[bool] = None, + ) -> EventBase: """Send a notice to the given user Creates the server notices room, if none exists. Args: - user_id (str): mxid of user to send event to. - event_content (dict): content of event to send - type(EventTypes): type of event - is_state_event(bool): Is the event a state event - - Returns: - Deferred[FrozenEvent] + user_id: mxid of user to send event to. + event_content: content of event to send + type: type of event + is_state_event: Is the event a state event """ room_id = await self.get_or_create_notice_room_for_user(user_id) await self.maybe_invite_user_to_room(user_id, room_id) @@ -89,17 +92,17 @@ class ServerNoticesManager(object): return event @cached() - async def get_or_create_notice_room_for_user(self, user_id): + async def get_or_create_notice_room_for_user(self, user_id: str) -> str: """Get the room for notices for a given user If we have not yet created a notice room for this user, create it, but don't invite the user to it. Args: - user_id (str): complete user id for the user we want a room for + user_id: complete user id for the user we want a room for Returns: - str: room id of notice room. + room id of notice room. """ if not self.is_enabled(): raise Exception("Server notices not enabled") @@ -163,7 +166,7 @@ class ServerNoticesManager(object): logger.info("Created server notices room %s for %s", room_id, user_id) return room_id - async def maybe_invite_user_to_room(self, user_id: str, room_id: str): + async def maybe_invite_user_to_room(self, user_id: str, room_id: str) -> None: """Invite the given user to the given server room, unless the user has already joined or been invited to it. diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index be74e86641..6870b67ca0 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -12,13 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterable, Union + from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) -class ServerNoticesSender(object): +class ServerNoticesSender: """A centralised place which sends server notices automatically when Certain Events take place """ @@ -32,22 +34,22 @@ class ServerNoticesSender(object): self._server_notices = ( ConsentServerNotices(hs), ResourceLimitsServerNotices(hs), - ) + ) # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]] - async def on_user_syncing(self, user_id): + async def on_user_syncing(self, user_id: str) -> None: """Called when the user performs a sync operation. Args: - user_id (str): mxid of user who synced + user_id: mxid of user who synced """ for sn in self._server_notices: await sn.maybe_send_server_notice_to_user(user_id) - async def on_user_ip(self, user_id): + async def on_user_ip(self, user_id: str) -> None: """Called on the master when a worker process saw a client request. Args: - user_id (str): mxid + user_id: mxid """ # The synchrotrons use a stubbed version of ServerNoticesSender, so # we check for notices to send to the user in on_user_ip as well as diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py index 245ec7c64f..9273e61895 100644 --- a/synapse/server_notices/worker_server_notices_sender.py +++ b/synapse/server_notices/worker_server_notices_sender.py @@ -12,10 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer -class WorkerServerNoticesSender(object): +class WorkerServerNoticesSender: """Stub impl of ServerNoticesSender which does nothing""" def __init__(self, hs): @@ -24,24 +23,18 @@ class WorkerServerNoticesSender(object): hs (synapse.server.HomeServer): """ - def on_user_syncing(self, user_id): + async def on_user_syncing(self, user_id: str) -> None: """Called when the user performs a sync operation. Args: - user_id (str): mxid of user who synced - - Returns: - Deferred + user_id: mxid of user who synced """ - return defer.succeed(None) + return None - def on_user_ip(self, user_id): + async def on_user_ip(self, user_id: str) -> None: """Called on the master when a worker process saw a client request. Args: - user_id (str): mxid - - Returns: - Deferred + user_id: mxid """ raise AssertionError("on_user_ip unexpectedly called on worker") diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9b78924d96..395ac5ab02 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from enum import Enum from twisted.internet import defer @@ -25,7 +26,17 @@ if MYPY: logger = logging.getLogger(__name__) -class SpamCheckerApi(object): +class RegistrationBehaviour(Enum): + """ + Enum to define whether a registration request should allowed, denied, or shadow-banned. + """ + + ALLOW = "allow" + SHADOW_BAN = "shadow_ban" + DENY = "deny" + + +class SpamCheckerApi: """A proxy object that gets passed to spam checkers so they can get access to rooms and other relevant information. """ @@ -48,8 +59,10 @@ class SpamCheckerApi(object): twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: The filtered state events in the room. """ - state_ids = yield self._store.get_filtered_current_state_ids( - room_id=room_id, state_filter=StateFilter.from_types(types) + state_ids = yield defer.ensureDeferred( + self._store.get_filtered_current_state_ids( + room_id=room_id, state_filter=StateFilter.from_types(types) + ) ) - state = yield self._store.get_events(state_ids.values()) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 2fa529fcd0..c7e3015b5d 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,15 +16,23 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Set - -from six import iteritems, itervalues +from typing import ( + Awaitable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Union, + cast, + overload, +) import attr from frozendict import frozendict from prometheus_client import Histogram - -from twisted.internet import defer +from typing_extensions import Literal from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions @@ -32,8 +40,10 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour -from synapse.types import StateMap +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.storage.roommember import ProfileInfo +from synapse.types import Collection, MutableStateMap, StateMap +from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func @@ -67,11 +77,17 @@ def _gen_state_id(): return s -class _StateCacheEntry(object): +class _StateCacheEntry: __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] - def __init__(self, state, state_group, prev_group=None, delta_ids=None): - # dict[(str, str), str] map from (type, state_key) to event_id + def __init__( + self, + state: StateMap[str], + state_group: Optional[int], + prev_group: Optional[int] = None, + delta_ids: Optional[StateMap[str]] = None, + ): + # A map from (type, state_key) to event_id. self.state = frozendict(state) # the ID of a state group if one and only one is involved. @@ -97,7 +113,7 @@ class _StateCacheEntry(object): return len(self.state) -class StateHandler(object): +class StateHandler: """Fetches bits of state from the stores, and does state resolution where necessary """ @@ -109,114 +125,131 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def get_current_state( - self, room_id, event_type=None, state_key="", latest_event_ids=None - ): - """ Retrieves the current state for the room. This is done by + @overload + async def get_current_state( + self, + room_id: str, + event_type: Literal[None] = None, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> StateMap[EventBase]: + ... + + @overload + async def get_current_state( + self, + room_id: str, + event_type: str, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> Optional[EventBase]: + ... + + async def get_current_state( + self, + room_id: str, + event_type: Optional[str] = None, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> Union[Optional[EventBase], StateMap[EventBase]]: + """Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. This is equivalent to getting the state of an event that were to send next before receiving any new events. - If `event_type` is specified, then the method returns only the one - event (or None) with that `event_type` and `state_key`. - Returns: - map from (type, state_key) to event + If `event_type` is specified, then the method returns only the one + event (or None) with that `event_type` and `state_key`. + + Otherwise, a map from (type, state_key) to event. """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state if event_type: event_id = state.get((event_type, state_key)) event = None if event_id: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) return event - state_map = yield self.store.get_events( + state_map = await self.store.get_events( list(state.values()), get_prev_content=False ) - state = { - key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map + return { + key: state_map[e_id] for key, e_id in state.items() if e_id in state_map } - return state - - @defer.inlineCallbacks - def get_current_state_ids(self, room_id, latest_event_ids=None): + async def get_current_state_ids( + self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None + ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room Args: - room_id (str): - - latest_event_ids (iterable[str]|None): if given, the forward - extremities to resolve. If None, we look them up from the - database (via a cache) + room_id: + latest_event_ids: if given, the forward extremities to resolve. If + None, we look them up from the database (via a cache). Returns: - Deferred[dict[(str, str), str)]]: the state dict, mapping from - (event_type, state_key) -> event_id + the state dict, mapping from (event_type, state_key) -> event_id """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) - state = ret.state + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + return ret.state - return state - - @defer.inlineCallbacks - def get_current_users_in_room(self, room_id, latest_event_ids=None): + async def get_current_users_in_room( + self, room_id: str, latest_event_ids: Optional[List[str]] = None + ) -> Dict[str, ProfileInfo]: """ Get the users who are currently in a room. Args: - room_id (str): The ID of the room. - latest_event_ids (List[str]|None): Precomputed list of latest - event IDs. Will be computed if None. + room_id: The ID of the room. + latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: - Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their - profileinfo. + Dictionary of user IDs to their profileinfo. """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None + logger.debug("calling resolve_state_groups from get_current_users_in_room") - entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_state(room_id, entry) - return joined_users + entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + return await self.store.get_joined_users_from_state(room_id, entry) - @defer.inlineCallbacks - def get_current_hosts_in_room(self, room_id): - event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - return (yield self.get_hosts_in_room_at_events(room_id, event_ids)) + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + event_ids = await self.store.get_latest_event_ids_in_room(room_id) + return await self.get_hosts_in_room_at_events(room_id, event_ids) - @defer.inlineCallbacks - def get_hosts_in_room_at_events(self, room_id, event_ids): + async def get_hosts_in_room_at_events( + self, room_id: str, event_ids: List[str] + ) -> Set[str]: """Get the hosts that were in a room at the given event ids Args: - room_id (str): - event_ids (list[str]): + room_id: + event_ids: Returns: - Deferred[list[str]]: the hosts in the room at the given events + The hosts in the room at the given events """ - entry = yield self.resolve_state_groups_for_events(room_id, event_ids) - joined_hosts = yield self.store.get_joined_hosts(room_id, entry) - return joined_hosts + entry = await self.resolve_state_groups_for_events(room_id, event_ids) + return await self.store.get_joined_hosts(room_id, entry) - @defer.inlineCallbacks - def compute_event_context( + async def compute_event_context( self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None - ): + ) -> EventContext: """Build an EventContext structure for the event. This works out what the current state should be for the event, and @@ -229,7 +262,7 @@ class StateHandler(object): when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. Returns: - synapse.events.snapshot.EventContext: + The event context. """ if event.internal_metadata.is_outlier(): @@ -270,7 +303,7 @@ class StateHandler(object): # if we're given the state before the event, then we use that state_ids_before_event = { (s.type, s.state_key): s.event_id for s in old_state - } + } # type: StateMap[str] state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None @@ -279,7 +312,7 @@ class StateHandler(object): # otherwise, we'll need to resolve the state across the prev_events. logger.debug("calling resolve_state_groups from compute_event_context") - entry = yield self.resolve_state_groups_for_events( + entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -296,7 +329,7 @@ class StateHandler(object): # if not state_group_before_event: - state_group_before_event = yield self.state_store.store_state_group( + state_group_before_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event_prev_group, @@ -336,7 +369,7 @@ class StateHandler(object): state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = yield self.state_store.store_state_group( + state_group_after_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event, @@ -354,27 +387,25 @@ class StateHandler(object): ) @measure_func() - @defer.inlineCallbacks - def resolve_state_groups_for_events(self, room_id, event_ids): + async def resolve_state_groups_for_events( + self, room_id: str, event_ids: Iterable[str] + ) -> _StateCacheEntry: """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. Args: - room_id (str) - event_ids (list[str]) - explicit_room_version (str|None): If set uses the the given room - version to choose the resolution algorithm. If None, then - checks the database for room version. + room_id + event_ids Returns: - Deferred[_StateCacheEntry]: resolved state + The resolved state """ logger.debug("resolve_state_groups event_ids %s", event_ids) # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.state_store.get_state_groups_ids( + state_groups_ids = await self.state_store.get_state_groups_ids( room_id, event_ids ) @@ -383,7 +414,7 @@ class StateHandler(object): elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) + prev_group, delta_ids = await self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, @@ -392,9 +423,9 @@ class StateHandler(object): delta_ids=delta_ids, ) - room_version = yield self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version_id(room_id) - result = yield self._state_resolution_handler.resolve_state_groups( + result = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, state_groups_ids, @@ -403,8 +434,12 @@ class StateHandler(object): ) return result - @defer.inlineCallbacks - def resolve_events(self, room_version, state_sets, event): + async def resolve_events( + self, + room_version: str, + state_sets: Collection[Iterable[EventBase]], + event: EventBase, + ) -> StateMap[EventBase]: logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -415,7 +450,8 @@ class StateHandler(object): state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( + new_state = await resolve_events_with_store( + self.clock, event.room_id, room_version, state_set_ids, @@ -423,12 +459,10 @@ class StateHandler(object): state_res_store=StateResolutionStore(self.store), ) - new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)} - - return new_state + return {key: state_map[ev_id] for key, ev_id in new_state.items()} -class StateResolutionHandler(object): +class StateResolutionHandler: """Responsible for doing state conflict resolution. Note that the storage layer depends on this handler, so all functions must @@ -451,10 +485,14 @@ class StateResolutionHandler(object): reset_expiry_on_get=True, ) - @defer.inlineCallbacks @log_function - def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_res_store + async def resolve_state_groups( + self, + room_id: str, + room_version: str, + state_groups_ids: Dict[int, StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", ): """Resolves conflicts between a set of state groups @@ -462,13 +500,13 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging and sanity checks) - room_version (str): version of the room - state_groups_ids (dict[int, dict[(str, str), str]]): - map from state group id to the state in that state group + room_id: room we are resolving for (used for logging and sanity checks) + room_version: version of the room + state_groups_ids: + A map from state group id to the state in that state group (where 'state' is a map from state key to event id) - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -476,16 +514,16 @@ class StateResolutionHandler(object): If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store Returns: - Deferred[_StateCacheEntry]: resolved state + The resolved state """ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) group_names = frozenset(state_groups_ids.keys()) - with (yield self.resolve_linearizer.queue(group_names)): + with (await self.resolve_linearizer.queue(group_names)): if self._state_cache is not None: cache = self._state_cache.get(group_names, None) if cache: @@ -503,10 +541,10 @@ class StateResolutionHandler(object): # # XXX: is this actually worthwhile, or should we just let # resolve_events_with_store do it? - new_state = {} + new_state = {} # type: MutableStateMap[str] conflicted_state = False - for st in itervalues(state_groups_ids): - for key, e_id in iteritems(st): + for st in state_groups_ids.values(): + for key, e_id in st.items(): if key in new_state: conflicted_state = True break @@ -517,12 +555,20 @@ class StateResolutionHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( - room_id, - room_version, - list(itervalues(state_groups_ids)), - event_map=event_map, - state_res_store=state_res_store, + # resolve_events_with_store returns a StateMap, but we can + # treat it as a MutableStateMap as it is above. It isn't + # actually mutated anymore (and is frozen in + # _make_state_cache_entry below). + new_state = cast( + MutableStateMap, + await resolve_events_with_store( + self.clock, + room_id, + room_version, + list(state_groups_ids.values()), + event_map=event_map, + state_res_store=state_res_store, + ), ) # if the new state matches any of the input state groups, we can @@ -539,21 +585,22 @@ class StateResolutionHandler(object): return cache -def _make_state_cache_entry(new_state, state_groups_ids): +def _make_state_cache_entry( + new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] +) -> _StateCacheEntry: """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed _StateCacheEntry. Args: - new_state (dict[(str, str), str]): resolved state map (mapping from - (type, state_key) to event_id) + new_state: resolved state map (mapping from (type, state_key) to event_id) - state_groups_ids (dict[int, dict[(str, str), str]]): - map from state group id to the state in that state group - (where 'state' is a map from state key to event id) + state_groups_ids: + map from state group id to the state in that state group (where + 'state' is a map from state key to event id) Returns: - _StateCacheEntry + The cache entry. """ # if the new state matches any of the input state groups, we can # use that state group again. Otherwise we will generate a state_id @@ -561,12 +608,12 @@ def _make_state_cache_entry(new_state, state_groups_ids): # not get persisted. # first look for exact matches - new_state_event_ids = set(itervalues(new_state)) - for sg, state in iteritems(state_groups_ids): + new_state_event_ids = set(new_state.values()) + for sg, state in state_groups_ids.items(): if len(new_state_event_ids) != len(state): continue - old_state_event_ids = set(itervalues(state)) + old_state_event_ids = set(state.values()) if new_state_event_ids == old_state_event_ids: # got an exact match. return _StateCacheEntry(state=new_state, state_group=sg) @@ -579,8 +626,8 @@ def _make_state_cache_entry(new_state, state_groups_ids): prev_group = None delta_ids = None - for old_group, old_state in iteritems(state_groups_ids): - n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v} + for old_group, old_state in state_groups_ids.items(): + n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} if not delta_ids or len(n_delta_ids) < len(delta_ids): prev_group = old_group delta_ids = n_delta_ids @@ -591,12 +638,13 @@ def _make_state_cache_entry(new_state, state_groups_ids): def resolve_events_with_store( + clock: Clock, room_id: str, room_version: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", -): +) -> Awaitable[StateMap[str]]: """ Args: room_id: the room we are working in @@ -617,8 +665,7 @@ def resolve_events_with_store( state_res_store: a place to fetch events from Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + a map from (type, state_key) to event_id. """ v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: @@ -627,12 +674,12 @@ def resolve_events_with_store( ) else: return v2.resolve_events_with_store( - room_id, room_version, state_sets, event_map, state_res_store + clock, room_id, room_version, state_sets, event_map, state_res_store ) @attr.s -class StateResolutionStore(object): +class StateResolutionStore: """Interface that allows state resolution algorithms to access the database in well defined way. @@ -642,15 +689,17 @@ class StateResolutionStore(object): store = attr.ib() - def get_events(self, event_ids, allow_rejected=False): + def get_events( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - allow_rejected (bool): If True return rejected events. + event_ids: The event_ids of the events to fetch + allow_rejected: If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + An awaitable which resolves to a dict from event_id to event. """ return self.store.get_events( @@ -660,7 +709,9 @@ class StateResolutionStore(object): allow_rejected=allow_rejected, ) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + def get_auth_chain_difference( + self, state_sets: List[Set[str]] + ) -> Awaitable[Set[str]]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -669,7 +720,7 @@ class StateResolutionStore(object): chain. Returns: - Deferred[Set[str]]: Set of event IDs. + An awaitable that resolves to a set of event IDs. """ return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 9bf98d06f2..a493279cbd 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,18 +15,24 @@ import hashlib import logging -from typing import Callable, Dict, List, Optional - -from six import iteritems, iterkeys, itervalues - -from twisted.internet import defer +from typing import ( + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap logger = logging.getLogger(__name__) @@ -34,13 +40,12 @@ logger = logging.getLogger(__name__) POWER_KEY = (EventTypes.PowerLevels, "") -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( room_id: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable, -): + state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], +) -> StateMap[str]: """ Args: room_id: the room we are working in @@ -58,11 +63,10 @@ def resolve_events_with_store( state_map_factory: will be called with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. + an Awaitable that resolves to a dict of event_id to event. Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + A map from (type, state_key) to event_id. """ if len(state_sets) == 1: return state_sets[0] @@ -70,19 +74,19 @@ def resolve_events_with_store( unconflicted_state, conflicted_state = _seperate(state_sets) needed_events = { - event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids + event_id for event_ids in conflicted_state.values() for event_id in event_ids } needed_event_count = len(needed_events) if event_map is not None: - needed_events -= set(iterkeys(event_map)) + needed_events -= set(event_map.keys()) logger.info( "Asking for %d/%d conflicted events", len(needed_events), needed_event_count ) - # dict[str, FrozenEvent]: a map from state event id to event. Only includes - # the state events which are in conflict (and those in event_map) - state_map = yield state_map_factory(needed_events) + # A map from state event id to event. Only includes the state events which + # are in conflict (and those in event_map). + state_map = await state_map_factory(needed_events) if event_map is not None: state_map.update(event_map) @@ -96,23 +100,21 @@ def resolve_events_with_store( # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. - # - # dict[(str, str), str]: a map from state key to event id auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) - new_needed_events = set(itervalues(auth_events)) + new_needed_events = set(auth_events.values()) new_needed_event_count = len(new_needed_events) new_needed_events -= needed_events if event_map is not None: - new_needed_events -= set(iterkeys(event_map)) + new_needed_events -= set(event_map.keys()) logger.info( "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count ) - state_map_new = yield state_map_factory(new_needed_events) + state_map_new = await state_map_factory(new_needed_events) for event in state_map_new.values(): if event.room_id != room_id: raise Exception( @@ -127,32 +129,33 @@ def resolve_events_with_store( ) -def _seperate(state_sets): +def _seperate( + state_sets: Iterable[StateMap[str]], +) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]: """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. Args: - state_sets(iterable[dict[(str, str), str]]): + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. Returns: - (dict[(str, str), str], dict[(str, str), set[str]]): - A tuple of (unconflicted_state, conflicted_state), where: + A tuple of (unconflicted_state, conflicted_state), where: - unconflicted_state is a dict mapping (type, state_key)->event_id - for unconflicted state keys. + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. - conflicted_state is a dict mapping (type, state_key) to a set of - event ids for conflicted state keys. + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. """ state_set_iterator = iter(state_sets) unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} + conflicted_state = {} # type: MutableStateMap[Set[str]] for state_set in state_set_iterator: - for key, value in iteritems(state_set): + for key, value in state_set.items(): # Check if there is an unconflicted entry for the state key. unconflicted_value = unconflicted_state.get(key) if unconflicted_value is None: @@ -176,25 +179,42 @@ def _seperate(state_sets): return unconflicted_state, conflicted_state -def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): +def _create_auth_events_from_maps( + unconflicted_state: StateMap[str], + conflicted_state: StateMap[Set[str]], + state_map: Dict[str, EventBase], +) -> StateMap[str]: + """ + + Args: + unconflicted_state: The unconflicted state map. + conflicted_state: The conflicted state map. + state_map: + + Returns: + A map from state key to event id. + """ auth_events = {} - for event_ids in itervalues(conflicted_state): + for event_ids in conflicted_state.values(): for event_id in event_ids: if event_id in state_map: keys = event_auth.auth_types_for_event(state_map[event_id]) for key in keys: if key not in auth_events: - event_id = unconflicted_state.get(key, None) - if event_id: - auth_events[key] = event_id + auth_event_id = unconflicted_state.get(key, None) + if auth_event_id: + auth_events[key] = auth_event_id return auth_events def _resolve_with_state( - unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map + unconflicted_state_ids: MutableStateMap[str], + conflicted_state_ids: StateMap[Set[str]], + auth_event_ids: StateMap[str], + state_map: Dict[str, EventBase], ): conflicted_state = {} - for key, event_ids in iteritems(conflicted_state_ids): + for key, event_ids in conflicted_state_ids.items(): events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] if len(events) > 1: conflicted_state[key] = events @@ -203,7 +223,7 @@ def _resolve_with_state( auth_events = { key: state_map[ev_id] - for key, ev_id in iteritems(auth_event_ids) + for key, ev_id in auth_event_ids.items() if ev_id in state_map } @@ -214,13 +234,15 @@ def _resolve_with_state( raise new_state = unconflicted_state_ids - for key, event in iteritems(resolved_state): + for key, event in resolved_state.items(): new_state[key] = event.event_id return new_state -def _resolve_state_events(conflicted_state, auth_events): +def _resolve_state_events( + conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase] +) -> StateMap[EventBase]: """ This is where we actually decide which of the conflicted state to use. @@ -238,21 +260,21 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in iteritems(conflicted_state): + for key, events in conflicted_state.items(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) - for key, events in iteritems(conflicted_state): + for key, events in conflicted_state.items(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) - for key, events in iteritems(conflicted_state): + for key, events in conflicted_state.items(): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) resolved_state[key] = _resolve_normal_events(events, auth_events) @@ -260,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events): return resolved_state -def _resolve_auth_events(events, auth_events): +def _resolve_auth_events( + events: List[EventBase], auth_events: StateMap[EventBase] +) -> EventBase: reverse = list(reversed(_ordered_events(events))) auth_keys = { @@ -294,7 +318,9 @@ def _resolve_auth_events(events, auth_events): return event -def _resolve_normal_events(events, auth_events): +def _resolve_normal_events( + events: List[EventBase], auth_events: StateMap[EventBase] +) -> EventBase: for event in _ordered_events(events): try: # The signatures have already been checked at this point @@ -314,7 +340,7 @@ def _resolve_normal_events(events, auth_events): return event -def _ordered_events(events): +def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]: def key_func(e): # we have to use utf-8 rather than ascii here because it turns out we allow # people to send us events with non-ascii event IDs :/ diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 18484e2fa6..edf94e7ad6 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,11 +16,21 @@ import heapq import itertools import logging -from typing import Dict, List, Optional - -from six import iteritems, itervalues - -from twisted.internet import defer +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + overload, +) + +from typing_extensions import Literal import synapse.state from synapse import event_auth @@ -28,29 +38,34 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap +from synapse.util import Clock logger = logging.getLogger(__name__) -@defer.inlineCallbacks -def resolve_events_with_store( +# We want to await to the reactor occasionally during state res when dealing +# with large data sets, so that we don't exhaust the reactor. This is done by +# awaiting to reactor during loops every N iterations. +_AWAIT_AFTER_ITERATIONS = 100 + + +async def resolve_events_with_store( + clock: Clock, room_id: str, room_version: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "synapse.state.StateResolutionStore", -): +) -> StateMap[str]: """Resolves the state using the v2 state resolution algorithm Args: + clock room_id: the room we are working in - room_version: The room version - state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be @@ -62,8 +77,7 @@ def resolve_events_with_store( state_res_store: Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + A map from (type, state_key) to event_id. """ logger.debug("Computing conflicted state") @@ -83,15 +97,15 @@ def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) full_conflicted_set = set( itertools.chain( - itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff + itertools.chain.from_iterable(conflicted_state.values()), auth_diff ) ) - events = yield state_res_store.get_events( + events = await state_res_store.get_events( [eid for eid in full_conflicted_set if eid not in event_map], allow_rejected=True, ) @@ -114,14 +128,15 @@ def resolve_events_with_store( eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) - sorted_power_events = yield _reverse_topological_power_sort( - room_id, power_events, event_map, state_res_store, full_conflicted_set + sorted_power_events = await _reverse_topological_power_sort( + clock, room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( + clock, room_id, room_version, sorted_power_events, @@ -135,20 +150,22 @@ def resolve_events_with_store( # OK, so we've now resolved the power events. Now sort the remaining # events using the mainline of the resolved power level. + set_power_events = set(sorted_power_events) leftover_events = [ - ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events + ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events ] logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) - leftover_events = yield _mainline_sort( - room_id, leftover_events, pl, event_map, state_res_store + leftover_events = await _mainline_sort( + clock, room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( + clock, room_id, room_version, leftover_events, @@ -167,25 +184,29 @@ def resolve_events_with_store( return resolved_state -@defer.inlineCallbacks -def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): +async def _get_power_level_for_sender( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> int: """Return the power level of the sender of the given event according to their auth events. Args: - room_id (str) - event_id (str) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + room_id + event_id + event_map + state_res_store Returns: - Deferred[int] + The power level. """ - event = yield _get_event(room_id, event_id, event_map, state_res_store) + event = await _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -195,7 +216,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""): @@ -214,38 +235,43 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): return int(level) -@defer.inlineCallbacks -def _get_auth_chain_difference(state_sets, event_map, state_res_store): +async def _get_auth_chain_difference( + state_sets: Sequence[StateMap[str]], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> Set[str]: """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. Args: - state_sets (list) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + state_sets + event_map + state_res_store Returns: - Deferred[set[str]]: Set of event IDs + Set of event IDs """ - difference = yield state_res_store.get_auth_chain_difference( + difference = await state_res_store.get_auth_chain_difference( [set(state_set.values()) for state_set in state_sets] ) return difference -def _seperate(state_sets): +def _seperate( + state_sets: Iterable[StateMap[str]], +) -> Tuple[StateMap[str], StateMap[Set[str]]]: """Return the unconflicted and conflicted state. This is different than in the original algorithm, as this defines a key to be conflicted if one of the state sets doesn't have that key. Args: - state_sets (list) + state_sets Returns: - tuple[dict, dict]: A tuple of unconflicted and conflicted state. The - conflicted state dict is a map from type/state_key to set of event IDs + A tuple of unconflicted and conflicted state. The conflicted state dict + is a map from type/state_key to set of event IDs """ unconflicted_state = {} conflicted_state = {} @@ -258,18 +284,20 @@ def _seperate(state_sets): event_ids.discard(None) conflicted_state[key] = event_ids - return unconflicted_state, conflicted_state + # mypy doesn't understand that discarding None above means that conflicted + # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]]. + return unconflicted_state, conflicted_state # type: ignore -def _is_power_event(event): +def _is_power_event(event: EventBase) -> bool: """Return whether or not the event is a "power event", as defined by the v2 state resolution algorithm Args: - event (FrozenEvent) + event Returns: - boolean + True if the event is a power event. """ if (event.type, event.state_key) in ( (EventTypes.PowerLevels, ""), @@ -285,21 +313,24 @@ def _is_power_event(event): return False -@defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph( - graph, room_id, event_id, event_map, state_res_store, auth_diff -): +async def _add_event_and_auth_chain_to_graph( + graph: Dict[str, Set[str]], + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + auth_diff: Set[str], +) -> None: """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph Args: - graph (dict[str, set[str]]): A map from event ID to the events auth - event IDs - room_id (str): the room we are working in - event_id (str): Event to add to the graph - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - auth_diff (set[str]): Set of event IDs that are in the auth difference. + graph: A map from event ID to the events auth event IDs + room_id: the room we are working in + event_id: Event to add to the graph + event_map + state_res_store + auth_diff: Set of event IDs that are in the auth difference. """ state = [event_id] @@ -307,7 +338,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(room_id, eid, event_map, state_res_store) + event = await _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -316,37 +347,52 @@ def _add_event_and_auth_chain_to_graph( graph.setdefault(eid, set()).add(aid) -@defer.inlineCallbacks -def _reverse_topological_power_sort( - room_id, event_ids, event_map, state_res_store, auth_diff -): +async def _reverse_topological_power_sort( + clock: Clock, + room_id: str, + event_ids: Iterable[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + auth_diff: Set[str], +) -> List[str]: """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: - room_id (str): the room we are working in - event_ids (list[str]): The events to sort - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - auth_diff (set[str]): Set of event IDs that are in the auth difference. + clock + room_id: the room we are working in + event_ids: The events to sort + event_map + state_res_store + auth_diff: Set of event IDs that are in the auth difference. Returns: - Deferred[list[str]]: The sorted list + The sorted list """ - graph = {} - for event_id in event_ids: - yield _add_event_and_auth_chain_to_graph( + graph = {} # type: Dict[str, Set[str]] + for idx, event_id in enumerate(event_ids, start=1): + await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ) + # We await occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) + event_to_pl = {} - for event_id in graph: - pl = yield _get_power_level_for_sender( + for idx, event_id in enumerate(graph, start=1): + pl = await _get_power_level_for_sender( room_id, event_id, event_map, state_res_store ) event_to_pl[event_id] = pl + # We await occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) + def _get_power_order(event_id): ev = event_map[event_id] pl = event_to_pl[event_id] @@ -360,33 +406,39 @@ def _reverse_topological_power_sort( return sorted_events -@defer.inlineCallbacks -def _iterative_auth_checks( - room_id, room_version, event_ids, base_state, event_map, state_res_store -): +async def _iterative_auth_checks( + clock: Clock, + room_id: str, + room_version: str, + event_ids: List[str], + base_state: StateMap[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> MutableStateMap[str]: """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: - room_id (str) - room_version (str) - event_ids (list[str]): Ordered list of events to apply auth checks to - base_state (StateMap[str]): The set of state to start with - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + clock + room_id + room_version + event_ids: Ordered list of events to apply auth checks to + base_state: The set of state to start with + event_map + state_res_store Returns: - Deferred[StateMap[str]]: Returns the final updated state + Returns the final updated state """ - resolved_state = base_state.copy() + resolved_state = dict(base_state) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - for event_id in event_ids: + for idx, event_id in enumerate(event_ids, start=1): event = event_map[event_id] auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) @@ -401,7 +453,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(room_id, ev_id, event_map, state_res_store) + ev = await _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -419,114 +471,173 @@ def _iterative_auth_checks( except AuthError: pass + # We await occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) + return resolved_state -@defer.inlineCallbacks -def _mainline_sort( - room_id, event_ids, resolved_power_event_id, event_map, state_res_store -): +async def _mainline_sort( + clock: Clock, + room_id: str, + event_ids: List[str], + resolved_power_event_id: Optional[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> List[str]: """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: - room_id (str): room we're working in - event_ids (list[str]): Events to sort - resolved_power_event_id (str): The final resolved power level event ID - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + clock + room_id: room we're working in + event_ids: Events to sort + resolved_power_event_id: The final resolved power level event ID + event_map + state_res_store Returns: - Deferred[list[str]]: The sorted list + The sorted list """ + if not event_ids: + # It's possible for there to be no event IDs here to sort, so we can + # skip calculating the mainline in that case. + return [] + mainline = [] pl = resolved_power_event_id + idx = 0 while pl: mainline.append(pl) - pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) + pl_ev = await _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break + # We await occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) + + idx += 1 + mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))} event_ids = list(event_ids) order_map = {} - for ev_id in event_ids: - depth = yield _get_mainline_depth_for_event( + for idx, ev_id in enumerate(event_ids, start=1): + depth = await _get_mainline_depth_for_event( event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) + # We await occasionally when we're working with large data sets to + # ensure that we don't block the reactor loop for too long. + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) + event_ids.sort(key=lambda ev_id: order_map[ev_id]) return event_ids -@defer.inlineCallbacks -def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): +async def _get_mainline_depth_for_event( + event: EventBase, + mainline_map: Dict[str, int], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> int: """Get the mainline depths for the given event based on the mainline map Args: - event (FrozenEvent) - mainline_map (dict[str, int]): Map from event_id to mainline depth for - events in the mainline. - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + event + mainline_map: Map from event_id to mainline depth for events in the mainline. + event_map + state_res_store Returns: - Deferred[int] + The mainline depth """ room_id = event.room_id + tmp_event = event # type: Optional[EventBase] # We do an iterative search, replacing `event with the power level in its # auth events (if any) - while event: + while tmp_event: depth = mainline_map.get(event.event_id) if depth is not None: return depth - auth_events = event.auth_event_ids() - event = None + auth_events = tmp_event.auth_event_ids() + tmp_event = None for aid in auth_events: - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): - event = aev + tmp_event = aev break # Didn't find a power level auth event, so we just return 0 return 0 -@defer.inlineCallbacks -def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): +@overload +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: Literal[False] = False, +) -> EventBase: + ... + + +@overload +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: Literal[True], +) -> Optional[EventBase]: + ... + + +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: bool = False, +) -> Optional[EventBase]: """Helper function to look up event in event_map, falling back to looking it up in the store Args: - room_id (str) - event_id (str) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - allow_none (bool): if the event is not found, return None rather than raising + room_id + event_id + event_map + state_res_store + allow_none: if the event is not found, return None rather than raising an exception Returns: - Deferred[Optional[FrozenEvent]] + The event, or none if the event does not exist (and allow_none is True). """ if event_id not in event_map: - events = yield state_res_store.get_events([event_id], allow_rejected=True) + events = await state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) event = event_map.get(event_id) @@ -543,7 +654,9 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): return event -def lexicographical_topological_sort(graph, key): +def lexicographical_topological_sort( + graph: Dict[str, Set[str]], key: Callable[[str], Any] +) -> Generator[str, None, None]: """Performs a lexicographic reverse topological sort on the graph. This returns a reverse topological sort (i.e. if node A references B then B @@ -553,26 +666,26 @@ def lexicographical_topological_sort(graph, key): NOTE: `graph` is modified during the sort. Args: - graph (dict[str, set[str]]): A representation of the graph where each - node is a key in the dict and its value are the nodes edges. - key (func): A function that takes a node and returns a value that is - comparable and used to order nodes + graph: A representation of the graph where each node is a key in the + dict and its value are the nodes edges. + key: A function that takes a node and returns a value that is comparable + and used to order nodes Yields: - str: The next node in the topological sort + The next node in the topological sort """ # Note, this is basically Kahn's algorithm except we look at nodes with no # outgoing edges, c.f. # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm outdegree_map = graph - reverse_graph = {} + reverse_graph = {} # type: Dict[str, Set[str]] # Lists of nodes with zero out degree. Is actually a tuple of # `(key(node), node)` so that sorting does the right thing zero_outdegree = [] - for node, edges in iteritems(graph): + for node, edges in graph.items(): if len(edges) == 0: zero_outdegree.append((key(node), node)) diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html index 6fefdaaff7..9e6daf38ac 100644 --- a/synapse/static/client/login/index.html +++ b/synapse/static/client/login/index.html @@ -1,24 +1,24 @@ <!doctype html> <html> <head> -<title> Login </title> -<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> -<link rel="stylesheet" href="style.css"> -<script src="js/jquery-3.4.1.min.js"></script> -<script src="js/login.js"></script> + <meta http-equiv="Content-Type" content="text/html; charset=utf-8"> + <title> Login </title> + <meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> + <link rel="stylesheet" href="style.css"> + <script src="js/jquery-3.4.1.min.js"></script> + <script src="js/login.js"></script> </head> <body onload="matrixLogin.onLoad()"> - <center> - <br/> + <div id="container"> <h1 id="title"></h1> - <span id="feedback" style="color: #f00"></span> + <span id="feedback"></span> <div id="loading"> <img src="spinner.gif" /> </div> - <div id="sso_flow" class="login_flow" style="display:none"> + <div id="sso_flow" class="login_flow" style="display: none;"> Single-sign on: <form id="sso_form" action="/_matrix/client/r0/login/sso/redirect" method="get"> <input id="sso_redirect_url" type="hidden" name="redirectUrl" value=""/> @@ -26,9 +26,9 @@ </form> </div> - <div id="password_flow" class="login_flow" style="display:none"> + <div id="password_flow" class="login_flow" style="display: none;"> Password Authentication: - <form onsubmit="matrixLogin.password_login(); return false;"> + <form onsubmit="matrixLogin.passwordLogin(); return false;"> <input id="user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" /> <br/> <input id="password" size="32" type="password" placeholder="Password"/> @@ -38,9 +38,9 @@ </form> </div> - <div id="no_login_types" type="button" class="login_flow" style="display:none"> + <div id="no_login_types" type="button" class="login_flow" style="display: none;"> Log in currently unavailable. </div> - </center> + </div> </body> </html> diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index ba8048b23f..3678670ec7 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -5,11 +5,11 @@ window.matrixLogin = { }; // Titles get updated through the process to give users feedback. -var TITLE_PRE_AUTH = "Log in with one of the following methods"; -var TITLE_POST_AUTH = "Logging in..."; +const TITLE_PRE_AUTH = "Log in with one of the following methods"; +const TITLE_POST_AUTH = "Logging in..."; // The cookie used to store the original query parameters when using SSO. -var COOKIE_KEY = "synapse_login_fallback_qs"; +const COOKIE_KEY = "synapse_login_fallback_qs"; /* * Submit a login request. @@ -20,9 +20,9 @@ var COOKIE_KEY = "synapse_login_fallback_qs"; * login request, e.g. device_id. * callback: (Optional) Function to call on successful login. */ -var submitLogin = function(type, data, extra, callback) { +function submitLogin(type, data, extra, callback) { console.log("Logging in with " + type); - set_title(TITLE_POST_AUTH); + setTitle(TITLE_POST_AUTH); // Add the login type. data.type = type; @@ -41,12 +41,15 @@ var submitLogin = function(type, data, extra, callback) { } matrixLogin.onLogin(response); }).fail(errorFunc); -}; +} -var errorFunc = function(err) { +/* + * Display an error to the user and show the login form again. + */ +function errorFunc(err) { // We want to show the error to the user rather than redirecting immediately to the // SSO portal (if SSO is the only login option), so we inhibit the redirect. - show_login(true); + showLogin(true); if (err.responseJSON && err.responseJSON.error) { setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")"); @@ -54,27 +57,42 @@ var errorFunc = function(err) { else { setFeedbackString("Request failed: " + err.status); } -}; +} -var setFeedbackString = function(text) { +/* + * Display an error to the user. + */ +function setFeedbackString(text) { $("#feedback").text(text); -}; +} -var show_login = function(inhibit_redirect) { - // Set the redirect to come back to this page, a login token will get added - // and handled after the redirect. - var this_page = window.location.origin + window.location.pathname; - $("#sso_redirect_url").val(this_page); +/* + * (Maybe) Show the login forms. + * + * This actually does a few unrelated functions: + * + * * Configures the SSO redirect URL to come back to this page. + * * Configures and shows the SSO form, if the server supports SSO. + * * Otherwise, shows the password form. + */ +function showLogin(inhibitRedirect) { + setTitle(TITLE_PRE_AUTH); - // If inhibit_redirect is false, and SSO is the only supported login method, + // If inhibitRedirect is false, and SSO is the only supported login method, // we can redirect straight to the SSO page. if (matrixLogin.serverAcceptsSso) { + // Set the redirect to come back to this page, a login token will get + // added as a query parameter and handled after the redirect. + $("#sso_redirect_url").val(window.location.origin + window.location.pathname); + // Before submitting SSO, set the current query parameters into a cookie // for retrieval later. var qs = parseQsFromUrl(); setCookie(COOKIE_KEY, JSON.stringify(qs)); - if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) { + // If password is not supported and redirects are allowed, then submit + // the form (redirecting to the SSO provider). + if (!inhibitRedirect && !matrixLogin.serverAcceptsPassword) { $("#sso_form").submit(); return; } @@ -87,30 +105,39 @@ var show_login = function(inhibit_redirect) { $("#password_flow").show(); } + // If neither password or SSO are supported, show an error to the user. if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) { $("#no_login_types").show(); } - set_title(TITLE_PRE_AUTH); - $("#loading").hide(); -}; +} -var show_spinner = function() { +/* + * Hides the forms and shows a loading throbber. + */ +function showSpinner() { $("#password_flow").hide(); $("#sso_flow").hide(); $("#no_login_types").hide(); $("#loading").show(); -}; +} -var set_title = function(title) { +/* + * Helper to show the page's main title. + */ +function setTitle(title) { $("#title").text(title); -}; +} -var fetch_info = function(cb) { +/* + * Query the login endpoint for the homeserver's supported flows. + * + * This populates matrixLogin.serverAccepts* variables. + */ +function fetchLoginFlows(cb) { $.get(matrixLogin.endpoint, function(response) { - var serverAcceptsPassword = false; - for (var i=0; i<response.flows.length; i++) { + for (var i = 0; i < response.flows.length; i++) { var flow = response.flows[i]; if ("m.login.sso" === flow.type) { matrixLogin.serverAcceptsSso = true; @@ -126,27 +153,41 @@ var fetch_info = function(cb) { }).fail(errorFunc); } +/* + * Called on load to fetch login flows and attempt SSO login (if a token is available). + */ matrixLogin.onLoad = function() { - fetch_info(function() { - if (!try_token()) { - show_login(false); + fetchLoginFlows(function() { + // (Maybe) attempt logging in via SSO if a token is available. + if (!tryTokenLogin()) { + showLogin(false); } }); }; -matrixLogin.password_login = function() { +/* + * Submit simple user & password login. + */ +matrixLogin.passwordLogin = function() { var user = $("#user_id").val(); var pwd = $("#password").val(); setFeedbackString(""); - show_spinner(); + showSpinner(); submitLogin( "m.login.password", {user: user, password: pwd}, parseQsFromUrl()); }; +/* + * The onLogin function gets called after a succesful login. + * + * It is expected that implementations override this to be notified when the + * login is complete. The response to the login call is provided as the single + * parameter. + */ matrixLogin.onLogin = function(response) { // clobber this function console.warn("onLogin - This function should be replaced to proceed."); @@ -155,7 +196,7 @@ matrixLogin.onLogin = function(response) { /* * Process the query parameters from the current URL into an object. */ -var parseQsFromUrl = function() { +function parseQsFromUrl() { var pos = window.location.href.indexOf("?"); if (pos == -1) { return {}; @@ -174,12 +215,12 @@ var parseQsFromUrl = function() { result[key] = val; }); return result; -}; +} /* * Process the cookies and return an object. */ -var parseCookies = function() { +function parseCookies() { var allCookies = document.cookie; var result = {}; allCookies.split(";").forEach(function(part) { @@ -196,32 +237,32 @@ var parseCookies = function() { result[key] = val; }); return result; -}; +} /* * Set a cookie that is valid for 1 hour. */ -var setCookie = function(key, value) { +function setCookie(key, value) { // The maximum age is set in seconds. var maxAge = 60 * 60; // Set the cookie, this defaults to the current domain and path. document.cookie = key + "=" + encodeURIComponent(value) + ";max-age=" + maxAge + ";sameSite=lax"; -}; +} /* * Removes a cookie by key. */ -var deleteCookie = function(key) { +function deleteCookie(key) { // Delete a cookie by setting the expiration to 0. (Note that the value // doesn't matter.) document.cookie = key + "=deleted;expires=0"; -}; +} /* * Submits the login token if one is found in the query parameters. Returns a * boolean of whether the login token was found or not. */ -var try_token = function() { +function tryTokenLogin() { // Check if the login token is in the query parameters. var qs = parseQsFromUrl(); @@ -233,18 +274,18 @@ var try_token = function() { // Retrieve the original query parameters (from before the SSO redirect). // They are stored as JSON in a cookie. var cookies = parseCookies(); - var original_query_params = JSON.parse(cookies[COOKIE_KEY] || "{}") + var originalQueryParams = JSON.parse(cookies[COOKIE_KEY] || "{}") // If the login is successful, delete the cookie. - var callback = function() { + function callback() { deleteCookie(COOKIE_KEY); } submitLogin( "m.login.token", {token: loginToken}, - original_query_params, + originalQueryParams, callback); return true; -}; +} diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css index 1cce5ed950..83e4f6abc8 100644 --- a/synapse/static/client/login/style.css +++ b/synapse/static/client/login/style.css @@ -31,20 +31,44 @@ form { margin: 10px 0 0 0; } +/* + * Add some padding to the viewport. + */ +#container { + padding: 10px; +} +/* + * Center all direct children of the main form. + */ +#container > * { + display: block; + margin-left: auto; + margin-right: auto; + text-align: center; +} + +/* + * A wrapper around each login flow. + */ .login_flow { width: 300px; text-align: left; padding: 10px; margin-bottom: 40px; - -webkit-border-radius: 10px; - -moz-border-radius: 10px; border-radius: 10px; - - -webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15); - -moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15); box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15); background-color: #f8f8f8; border: 1px #ccc solid; } + +/* + * Used to show error content. + */ +#feedback { + /* Red text. */ + color: #ff0000; + /* A little space to not overlap the box-shadow. */ + margin-bottom: 20px; +} diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index ec89f645d4..8e5d78f6f7 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -17,18 +17,19 @@ """ The storage layer is split up into multiple parts to allow Synapse to run against different configurations of databases (e.g. single or multiple -databases). The `Database` class represents a single physical database. The -`data_stores` are classes that talk directly to a `Database` instance and have -associated schemas, background updates, etc. On top of those there are classes -that provide high level interfaces that combine calls to multiple `data_stores`. +databases). The `DatabasePool` class represents connections to a single physical +database. The `databases` are classes that talk directly to a `DatabasePool` +instance and have associated schemas, background updates, etc. On top of those +there are classes that provide high level interfaces that combine calls to +multiple `databases`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from synapse.storage.data_stores import DataStores -from synapse.storage.data_stores.main import DataStore +from synapse.storage.databases import Databases +from synapse.storage.databases.main import DataStore from synapse.storage.persist_events import EventsPersistenceStorage from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.state import StateGroupStorage @@ -36,11 +37,11 @@ from synapse.storage.state import StateGroupStorage __all__ = ["DataStores", "DataStore"] -class Storage(object): +class Storage: """The high level interfaces for talking to various storage layers. """ - def __init__(self, hs, stores: DataStores): + def __init__(self, hs, stores: Databases): # We include the main data store here mainly so that we don't have to # rewrite all the existing code to split it into high vs low level # interfaces. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index bfce541ca7..ab49d227de 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,12 +19,11 @@ import random from abc import ABCMeta from typing import Any, Optional -from canonicaljson import json - from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.types import Collection, get_domain_from_id +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -37,11 +36,11 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine - self.db = database + self.db_pool = database self.rand = random.SystemRandom() def process_replication_rows(self, stream_name, instance_name, token, rows): @@ -58,7 +57,6 @@ class SQLBaseStore(metaclass=ABCMeta): """ for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) - self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) @@ -100,13 +98,13 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # Decode it to a Unicode string before feeding it to json.loads, so we - # consistenty get a Unicode-containing object out. + # Decode it to a Unicode string before feeding it to the JSON decoder, since + # Python 3.5 does not support deserializing bytes. if isinstance(db_content, (bytes, bytearray)): db_content = db_content.decode("utf8") try: - return json.loads(db_content) + return json_decoder.decode(db_content) except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 59f3394b0a..810721ebe9 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -16,18 +16,15 @@ import logging from typing import Optional -from canonicaljson import json - -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import json_encoder from . import engines logger = logging.getLogger(__name__) -class BackgroundUpdatePerformance(object): +class BackgroundUpdatePerformance: """Tracks the how long a background update is taking to update its items""" def __init__(self, name): @@ -74,7 +71,7 @@ class BackgroundUpdatePerformance(object): return float(self.total_item_count) / float(self.total_duration_ms) -class BackgroundUpdater(object): +class BackgroundUpdater: """ Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to @@ -88,7 +85,7 @@ class BackgroundUpdater(object): def __init__(self, hs, database): self._clock = hs.get_clock() - self.db = database + self.db_pool = database # if a background update is currently running, its name. self._current_background_update = None # type: Optional[str] @@ -139,7 +136,7 @@ class BackgroundUpdater(object): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = await self.db.simple_select_onecol( + updates = await self.db_pool.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -160,7 +157,7 @@ class BackgroundUpdater(object): if update_name == self._current_background_update: return False - update_exists = await self.db.simple_select_one_onecol( + update_exists = await self.db_pool.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -189,10 +186,10 @@ class BackgroundUpdater(object): ORDER BY ordering, update_name """ ) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) if not self._current_background_update: - all_pending_updates = await self.db.runInteraction( + all_pending_updates = await self.db_pool.runInteraction( "background_updates", get_background_updates_txn, ) if not all_pending_updates: @@ -243,13 +240,16 @@ class BackgroundUpdater(object): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = await self.db.simple_select_one_onecol( + progress_json = await self.db_pool.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", ) - progress = json.loads(progress_json) + # Avoid a circular import. + from synapse.storage._base import db_to_json + + progress = db_to_json(progress_json) time_start = self._clock.time_msec() items_updated = await update_handler(progress, batch_size) @@ -305,9 +305,8 @@ class BackgroundUpdater(object): update_name (str): Name of update """ - @defer.inlineCallbacks - def noop_update(progress, batch_size): - yield self._end_background_update(update_name) + async def noop_update(progress, batch_size): + await self._end_background_update(update_name) return 1 self.register_background_update_handler(update_name, noop_update) @@ -399,30 +398,30 @@ class BackgroundUpdater(object): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.db.engine, engines.PostgresEngine): + if isinstance(self.db_pool.engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None else: runner = create_index_sqlite - @defer.inlineCallbacks - def updater(progress, batch_size): + async def updater(progress, batch_size): if runner is not None: logger.info("Adding index %s to %s", index_name, table) - yield self.db.runWithConnection(runner) - yield self._end_background_update(update_name) + await self.db_pool.runWithConnection(runner) + await self._end_background_update(update_name) return 1 self.register_background_update_handler(update_name, updater) - def _end_background_update(self, update_name): + async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. Args: - update_name(str): The name of the completed task to remove + update_name:: The name of the completed task to remove + Returns: - A deferred that completes once the task is removed. + None, completes once the task is removed. """ if update_name != self._current_background_update: raise Exception( @@ -430,11 +429,11 @@ class BackgroundUpdater(object): % update_name ) self._current_background_update = None - return self.db.simple_delete_one( + await self.db_pool.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) - def _background_update_progress(self, update_name: str, progress: dict): + async def _background_update_progress(self, update_name: str, progress: dict): """Update the progress of a background update Args: @@ -442,7 +441,7 @@ class BackgroundUpdater(object): progress: The progress of the update. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "background_update_progress", self._background_update_progress_txn, update_name, @@ -458,9 +457,9 @@ class BackgroundUpdater(object): progress(dict): The progress of the update. """ - progress_json = json.dumps(progress) + progress_json = json_encoder.encode(progress) - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py deleted file mode 100644 index 599ee470d4..0000000000 --- a/synapse/storage/data_stores/__init__.py +++ /dev/null @@ -1,97 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.storage.data_stores.main.events import PersistEventsStore -from synapse.storage.data_stores.state import StateGroupDataStore -from synapse.storage.database import Database, make_conn -from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database - -logger = logging.getLogger(__name__) - - -class DataStores(object): - """The various data stores. - - These are low level interfaces to physical databases. - - Attributes: - main (DataStore) - """ - - def __init__(self, main_store_class, hs): - # Note we pass in the main store class here as workers use a different main - # store. - - self.databases = [] - self.main = None - self.state = None - self.persist_events = None - - for database_config in hs.config.database.databases: - db_name = database_config.name - engine = create_engine(database_config.config) - - with make_conn(database_config, engine) as db_conn: - logger.info("Preparing database %r...", db_name) - - engine.check_database(db_conn) - prepare_database( - db_conn, engine, hs.config, data_stores=database_config.data_stores, - ) - - database = Database(hs, database_config, engine) - - if "main" in database_config.data_stores: - logger.info("Starting 'main' data store") - - # Sanity check we don't try and configure the main store on - # multiple databases. - if self.main: - raise Exception("'main' data store already configured") - - self.main = main_store_class(database, db_conn, hs) - - # If we're on a process that can persist events also - # instantiate a `PersistEventsStore` - if hs.config.worker.writers.events == hs.get_instance_name(): - self.persist_events = PersistEventsStore( - hs, database, self.main - ) - - if "state" in database_config.data_stores: - logger.info("Starting 'state' data store") - - # Sanity check we don't try and configure the state store on - # multiple databases. - if self.state: - raise Exception("'state' data store already configured") - - self.state = StateGroupDataStore(database, db_conn, hs) - - db_conn.commit() - - self.databases.append(database) - - logger.info("Database %r prepared", db_name) - - # Sanity check that we have actually configured all the required stores. - if not self.main: - raise Exception("No 'main' data store configured") - - if not self.state: - raise Exception("No 'main' data store configured") diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b112ff3df2..ed8a9bffb1 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -16,6 +16,7 @@ # limitations under the License. import logging import time +from sys import intern from time import monotonic as monotonic_time from typing import ( Any, @@ -27,15 +28,14 @@ from typing import ( Optional, Tuple, TypeVar, + cast, + overload, ) -from six import iteritems, iterkeys, itervalues -from six.moves import intern, range - from prometheus_client import Histogram +from typing_extensions import Literal from twisted.enterprise import adbapi -from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -51,11 +51,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E from synapse.storage.types import Connection, Cursor from synapse.types import Collection -logger = logging.getLogger(__name__) - # python 3 does not have a maximum int value MAX_TXN_ID = 2 ** 63 - 1 +logger = logging.getLogger(__name__) + sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") perf_logger = logging.getLogger("synapse.storage.TIME") @@ -127,7 +127,7 @@ class LoggingTransaction: method. Args: - txn: The database transcation object to wrap. + txn: The database transaction object to wrap. name: The name of this transactions for logging. database_engine after_callbacks: A list that callbacks will be appended to @@ -162,7 +162,7 @@ class LoggingTransaction: self.after_callbacks = after_callbacks self.exception_callbacks = exception_callbacks - def call_after(self, callback: "Callable[..., None]", *args, **kwargs): + def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. @@ -173,7 +173,9 @@ class LoggingTransaction: assert self.after_callbacks is not None self.after_callbacks.append((callback, args, kwargs)) - def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs): + def call_on_exception( + self, callback: "Callable[..., None]", *args: Any, **kwargs: Any + ): # if self.exception_callbacks is None, that means that whatever constructed the # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. @@ -197,7 +199,7 @@ class LoggingTransaction: def description(self) -> Any: return self.txn.description - def execute_batch(self, sql, args): + def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: if isinstance(self.database_engine, PostgresEngine): from psycopg2.extras import execute_batch # type: ignore @@ -206,17 +208,17 @@ class LoggingTransaction: for val in args: self.execute(sql, val) - def execute(self, sql: str, *args: Any): + def execute(self, sql: str, *args: Any) -> None: self._do_execute(self.txn.execute, sql, *args) - def executemany(self, sql: str, *args: Any): + def executemany(self, sql: str, *args: Any) -> None: self._do_execute(self.txn.executemany, sql, *args) def _make_sql_one_line(self, sql: str) -> str: "Strip newlines out of SQL so that the loggers in the DB are on one line" return " ".join(line.strip() for line in sql.splitlines() if line.strip()) - def _do_execute(self, func, sql, *args): + def _do_execute(self, func, sql: str, *args: Any) -> None: sql = self._make_sql_one_line(sql) # TODO(paul): Maybe use 'info' and 'debug' for values? @@ -235,31 +237,31 @@ class LoggingTransaction: try: return func(sql, *args) except Exception as e: - logger.debug("[SQL FAIL] {%s} %s", self.name, e) + sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e) raise finally: secs = time.time() - start sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) sql_query_timer.labels(sql.split()[0]).observe(secs) - def close(self): + def close(self) -> None: self.txn.close() -class PerformanceCounters(object): +class PerformanceCounters: def __init__(self): self.current_counters = {} self.previous_counters = {} - def update(self, key, duration_secs): + def update(self, key: str, duration_secs: float) -> None: count, cum_time = self.current_counters.get(key, (0, 0)) count += 1 cum_time += duration_secs self.current_counters[key] = (count, cum_time) - def interval(self, interval_duration_secs, limit=3): + def interval(self, interval_duration_secs: float, limit: int = 3) -> str: counters = [] - for name, (count, cum_time) in iteritems(self.current_counters): + for name, (count, cum_time) in self.current_counters.items(): prev_count, prev_time = self.previous_counters.get(name, (0, 0)) counters.append( ( @@ -281,7 +283,10 @@ class PerformanceCounters(object): return top_n_counters -class Database(object): +R = TypeVar("R") + + +class DatabasePool: """Wraps a single physical database and connection pool. A single database may be used by multiple data stores. @@ -329,13 +334,12 @@ class Database(object): self._check_safe_to_upsert, ) - def is_running(self): + def is_running(self) -> bool: """Is the database pool currently running """ return self._db_pool.running - @defer.inlineCallbacks - def _check_safe_to_upsert(self): + async def _check_safe_to_upsert(self) -> None: """ Is it safe to use native UPSERT? @@ -344,7 +348,7 @@ class Database(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self.simple_select_list( + updates = await self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -366,7 +370,7 @@ class Database(object): self._check_safe_to_upsert, ) - def start_profiling(self): + def start_profiling(self) -> None: self._previous_loop_ts = monotonic_time() def loop(): @@ -390,8 +394,15 @@ class Database(object): self._clock.looping_call(loop, 10000) def new_transaction( - self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs - ): + self, + conn: Connection, + desc: str, + after_callbacks: List[_CallbackListEntry], + exception_callbacks: List[_CallbackListEntry], + func: "Callable[..., R]", + *args: Any, + **kwargs: Any + ) -> R: start = monotonic_time() txn_id = self._TXN_ID @@ -421,7 +432,7 @@ class Database(object): except self.engine.module.OperationalError as e: # This can happen if the database disappears mid # transaction. - logger.warning( + transaction_logger.warning( "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, ) if i < N: @@ -429,18 +440,20 @@ class Database(object): try: conn.rollback() except self.engine.module.Error as e1: - logger.warning("[TXN EROLL] {%s} %s", name, e1) + transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1) continue raise except self.engine.module.DatabaseError as e: if self.engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + transaction_logger.warning( + "[TXN DEADLOCK] {%s} %d/%d", name, i, N + ) if i < N: i += 1 try: conn.rollback() except self.engine.module.Error as e1: - logger.warning( + transaction_logger.warning( "[TXN EROLL] {%s} %s", name, e1, ) continue @@ -480,7 +493,7 @@ class Database(object): # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 cursor.close() except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) + transaction_logger.debug("[TXN FAIL] {%s} %s", name, e) raise finally: end = monotonic_time() @@ -494,8 +507,9 @@ class Database(object): self._txn_perf_counters.update(desc, duration) sql_txn_timer.labels(desc).observe(duration) - @defer.inlineCallbacks - def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any): + async def runInteraction( + self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any + ) -> R: """Starts a transaction on the database and runs a given function Arguments: @@ -508,7 +522,7 @@ class Database(object): kwargs: named args to pass to `func` Returns: - Deferred: The result of func + The result of func """ after_callbacks = [] # type: List[_CallbackListEntry] exception_callbacks = [] # type: List[_CallbackListEntry] @@ -517,7 +531,7 @@ class Database(object): logger.warning("Starting db txn '%s' from sentinel context", desc) try: - result = yield self.runWithConnection( + result = await self.runWithConnection( self.new_transaction, desc, after_callbacks, @@ -534,10 +548,11 @@ class Database(object): after_callback(*after_args, **after_kwargs) raise - return result + return cast(R, result) - @defer.inlineCallbacks - def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any): + async def runWithConnection( + self, func: "Callable[..., R]", *args: Any, **kwargs: Any + ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. Arguments: @@ -548,7 +563,7 @@ class Database(object): kwargs: named args to pass to `func` Returns: - Deferred: The result of func + The result of func """ parent_context = current_context() # type: Optional[LoggingContextOrSentinel] if not parent_context: @@ -571,18 +586,16 @@ class Database(object): return func(conn, *args, **kwargs) - result = yield make_deferred_yieldable( + return await make_deferred_yieldable( self._db_pool.runWithConnection(inner_func, *args, **kwargs) ) - return result - @staticmethod - def cursor_to_dict(cursor): + def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]: """Converts a SQL cursor into an list of dicts. Args: - cursor : The DBAPI cursor which has executed a query. + cursor: The DBAPI cursor which has executed a query. Returns: A list of dicts where the key is the column header. """ @@ -590,10 +603,29 @@ class Database(object): results = [dict(zip(col_headers, row)) for row in cursor] return results - def execute(self, desc, decoder, query, *args): + @overload + async def execute( + self, desc: str, decoder: Literal[None], query: str, *args: Any + ) -> List[Tuple[Any, ...]]: + ... + + @overload + async def execute( + self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any + ) -> R: + ... + + async def execute( + self, + desc: str, + decoder: Optional[Callable[[Cursor], R]], + query: str, + *args: Any + ) -> R: """Runs a single query for a result set. Args: + desc: description of the transaction, for logging and metrics decoder - The function which can resolve the cursor results to something meaningful. query - The query string to execute @@ -609,29 +641,33 @@ class Database(object): else: return txn.fetchall() - return self.runInteraction(desc, interaction) + return await self.runInteraction(desc, interaction) # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. - @defer.inlineCallbacks - def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + async def simple_insert( + self, + table: str, + values: Dict[str, Any], + or_ignore: bool = False, + desc: str = "simple_insert", + ) -> bool: """Executes an INSERT query on the named table. Args: - table : string giving the table name - values : dict of new column names and values for them - or_ignore : bool stating whether an exception should be raised + table: string giving the table name + values: dict of new column names and values for them + or_ignore: bool stating whether an exception should be raised when a conflicting row already exists. If True, False will be returned by the function instead - desc : string giving a description of the transaction + desc: description of the transaction, for logging and metrics Returns: - bool: Whether the row was inserted or not. Only useful when - `or_ignore` is True + Whether the row was inserted or not. Only useful when `or_ignore` is True """ try: - yield self.runInteraction(desc, self.simple_insert_txn, table, values) + await self.runInteraction(desc, self.simple_insert_txn, table, values) except self.engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -641,7 +677,9 @@ class Database(object): return True @staticmethod - def simple_insert_txn(txn, table, values): + def simple_insert_txn( + txn: LoggingTransaction, table: str, values: Dict[str, Any] + ) -> None: keys, vals = zip(*values.items()) sql = "INSERT INTO %s (%s) VALUES(%s)" % ( @@ -652,11 +690,29 @@ class Database(object): txn.execute(sql, vals) - def simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self.simple_insert_many_txn, table, values) + async def simple_insert_many( + self, table: str, values: List[Dict[str, Any]], desc: str + ) -> None: + """Executes an INSERT query on the named table. + + Args: + table: string giving the table name + values: dict of new column names and values for them + desc: description of the transaction, for logging and metrics + """ + await self.runInteraction(desc, self.simple_insert_many_txn, table, values) @staticmethod - def simple_insert_many_txn(txn, table, values): + def simple_insert_many_txn( + txn: LoggingTransaction, table: str, values: List[Dict[str, Any]] + ) -> None: + """Executes an INSERT query on the named table. + + Args: + txn: The transaction to use. + table: string giving the table name + values: dict of new column names and values for them + """ if not values: return @@ -684,16 +740,15 @@ class Database(object): txn.executemany(sql, vals) - @defer.inlineCallbacks - def simple_upsert( + async def simple_upsert( self, - table, - keyvalues, - values, - insertion_values={}, - desc="simple_upsert", - lock=True, - ): + table: str, + keyvalues: Dict[str, Any], + values: Dict[str, Any], + insertion_values: Dict[str, Any] = {}, + desc: str = "simple_upsert", + lock: bool = True, + ) -> Optional[bool]: """ `lock` should generally be set to True (the default), but can be set @@ -707,21 +762,20 @@ class Database(object): this table. Args: - table (str): The table to upsert into - keyvalues (dict): The unique key columns and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. + table: The table to upsert into + keyvalues: The unique key columns and their new values + values: The nonunique columns and their new values + insertion_values: additional key/values to use only when inserting + desc: description of the transaction, for logging and metrics + lock: True to lock the table when doing the upsert. Returns: - Deferred(None or bool): Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. + Native upserts always return None. Emulated upserts return True if a + new entry was created, False if an existing one was updated. """ attempts = 0 while True: try: - result = yield self.runInteraction( + return await self.runInteraction( desc, self.simple_upsert_txn, table, @@ -730,7 +784,6 @@ class Database(object): insertion_values, lock=lock, ) - return result except self.engine.module.IntegrityError as e: attempts += 1 if attempts >= 5: @@ -744,29 +797,34 @@ class Database(object): ) def simple_upsert_txn( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): + self, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + values: Dict[str, Any], + insertion_values: Dict[str, Any] = {}, + lock: bool = True, + ) -> Optional[bool]: """ Pick the UPSERT method which works best on the platform. Either the native one (Pg9.5+, recent SQLites), or fall back to an emulated method. Args: txn: The transaction to use. - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. + table: The table to upsert into + keyvalues: The unique key tables and their new values + values: The nonunique columns and their new values + insertion_values: additional key/values to use only when inserting + lock: True to lock the table when doing the upsert. Returns: - None or bool: Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. + Native upserts always return None. Emulated upserts return True if a + new entry was created, False if an existing one was updated. """ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: - return self.simple_upsert_txn_native_upsert( + self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) + return None else: return self.simple_upsert_txn_emulated( txn, @@ -778,18 +836,23 @@ class Database(object): ) def simple_upsert_txn_emulated( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): + self, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + values: Dict[str, Any], + insertion_values: Dict[str, Any] = {}, + lock: bool = True, + ) -> bool: """ Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. + table: The table to upsert into + keyvalues: The unique key tables and their new values + values: The nonunique columns and their new values + insertion_values: additional key/values to use only when inserting + lock: True to lock the table when doing the upsert. Returns: - bool: Return True if a new entry was created, False if an existing + Returns True if a new entry was created, False if an existing one was updated. """ # We need to lock the table :(, unless we're *really* careful @@ -847,19 +910,21 @@ class Database(object): return True def simple_upsert_txn_native_upsert( - self, txn, table, keyvalues, values, insertion_values={} - ): + self, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + values: Dict[str, Any], + insertion_values: Dict[str, Any] = {}, + ) -> None: """ Use the native UPSERT functionality in recent PostgreSQL versions. Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - Returns: - None + table: The table to upsert into + keyvalues: The unique key tables and their new values + values: The nonunique columns and their new values + insertion_values: additional key/values to use only when inserting """ allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) @@ -989,41 +1054,93 @@ class Database(object): return txn.execute_batch(sql, args) - def simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" - ): + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[False] = False, + desc: str = "simple_select_one", + ) -> Dict[str, Any]: + ... + + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[True] = True, + desc: str = "simple_select_one", + ) -> Optional[Dict[str, Any]]: + ... + + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: bool = False, + desc: str = "simple_select_one", + ) -> Optional[Dict[str, Any]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcols : list of strings giving the names of the columns to return - - allow_none : If true, return None instead of failing if the SELECT - statement returns no rows + table: string giving the table name + keyvalues: dict of column names and values to select the row with + retcols: list of strings giving the names of the columns to return + allow_none: If true, return None instead of failing if the SELECT + statement returns no rows + desc: description of the transaction, for logging and metrics """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def simple_select_one_onecol( + @overload + async def simple_select_one_onecol( self, - table, - keyvalues, - retcol, - allow_none=False, - desc="simple_select_one_onecol", - ): + table: str, + keyvalues: Dict[str, Any], + retcol: str, + allow_none: Literal[False] = False, + desc: str = "simple_select_one_onecol", + ) -> Any: + ... + + @overload + async def simple_select_one_onecol( + self, + table: str, + keyvalues: Dict[str, Any], + retcol: str, + allow_none: Literal[True] = True, + desc: str = "simple_select_one_onecol", + ) -> Optional[Any]: + ... + + async def simple_select_one_onecol( + self, + table: str, + keyvalues: Dict[str, Any], + retcol: str, + allow_none: bool = False, + desc: str = "simple_select_one_onecol", + ) -> Optional[Any]: """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcol : string giving the name of the column to return + table: string giving the table name + keyvalues: dict of column names and values to select the row with + retcol: string giving the name of the column to return + allow_none: If true, return None instead of failing if the SELECT + statement returns no rows + desc: description of the transaction, for logging and metrics """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_onecol_txn, table, @@ -1032,10 +1149,39 @@ class Database(object): allow_none=allow_none, ) + @overload @classmethod def simple_select_one_onecol_txn( - cls, txn, table, keyvalues, retcol, allow_none=False - ): + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcol: str, + allow_none: Literal[False] = False, + ) -> Any: + ... + + @overload + @classmethod + def simple_select_one_onecol_txn( + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcol: str, + allow_none: Literal[True] = True, + ) -> Optional[Any]: + ... + + @classmethod + def simple_select_one_onecol_txn( + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcol: str, + allow_none: bool = False, + ) -> Optional[Any]: ret = cls.simple_select_onecol_txn( txn, table=table, keyvalues=keyvalues, retcol=retcol ) @@ -1049,64 +1195,85 @@ class Database(object): raise StoreError(404, "No row found") @staticmethod - def simple_select_onecol_txn(txn, table, keyvalues, retcol): + def simple_select_onecol_txn( + txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str, + ) -> List[Any]: sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} if keyvalues: - sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) txn.execute(sql, list(keyvalues.values())) else: txn.execute(sql) return [r[0] for r in txn] - def simple_select_onecol( - self, table, keyvalues, retcol, desc="simple_select_onecol" - ): + async def simple_select_onecol( + self, + table: str, + keyvalues: Optional[Dict[str, Any]], + retcol: str, + desc: str = "simple_select_onecol", + ) -> List[Any]: """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. Args: - table (str): table name - keyvalues (dict|None): column names and values to select the rows with - retcol (str): column whos value we wish to retrieve. + table: table name + keyvalues: column names and values to select the rows with + retcol: column whos value we wish to retrieve. + desc: description of the transaction, for logging and metrics Returns: - Deferred: Results in a list + Results in a list """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): + async def simple_select_list( + self, + table: str, + keyvalues: Optional[Dict[str, Any]], + retcols: Iterable[str], + desc: str = "simple_select_list", + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Args: - table (str): the table name - keyvalues (dict[str, Any] | None): + table: the table name + keyvalues: column names and values to select the rows with, or None to not apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return + retcols: the names of the columns to return + desc: description of the transaction, for logging and metrics + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries. """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_list_txn, table, keyvalues, retcols ) @classmethod - def simple_select_list_txn(cls, txn, table, keyvalues, retcols): + def simple_select_list_txn( + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Optional[Dict[str, Any]], + retcols: Iterable[str], + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): + txn: Transaction object + table: the table name + keyvalues: column names and values to select the rows with, or None to not apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return + retcols: the names of the columns to return """ if keyvalues: sql = "SELECT %s FROM %s WHERE %s" % ( @@ -1121,28 +1288,29 @@ class Database(object): return cls.cursor_to_dict(txn) - @defer.inlineCallbacks - def simple_select_many_batch( + async def simple_select_many_batch( self, - table, - column, - iterable, - retcols, - keyvalues={}, - desc="simple_select_many_batch", - batch_size=100, - ): + table: str, + column: str, + iterable: Iterable[Any], + retcols: Iterable[str], + keyvalues: Dict[str, Any] = {}, + desc: str = "simple_select_many_batch", + batch_size: int = 100, + ) -> List[Any]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. - Filters rows by if value of `column` is in `iterable`. + Filters rows by whether the value of `column` is in `iterable`. Args: - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + retcols: list of strings giving the names of the columns to return + keyvalues: dict of column names and values to select the rows with + desc: description of the transaction, for logging and metrics + batch_size: the number of rows for each select query """ results = [] # type: List[Dict[str, Any]] @@ -1156,7 +1324,7 @@ class Database(object): it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) ] for chunk in chunks: - rows = yield self.runInteraction( + rows = await self.runInteraction( desc, self.simple_select_many_txn, table, @@ -1171,19 +1339,27 @@ class Database(object): return results @classmethod - def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + def simple_select_many_txn( + cls, + txn: LoggingTransaction, + table: str, + column: str, + iterable: Iterable[Any], + keyvalues: Dict[str, Any], + retcols: Iterable[str], + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. - Filters rows by if value of `column` is in `iterable`. + Filters rows by whether the value of `column` is in `iterable`. Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return + txn: Transaction object + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + keyvalues: dict of column names and values to select the rows with + retcols: list of strings giving the names of the columns to return """ if not iterable: return [] @@ -1191,7 +1367,7 @@ class Database(object): clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) clauses = [clause] - for key, value in iteritems(keyvalues): + for key, value in keyvalues.items(): clauses.append("%s = ?" % (key,)) values.append(value) @@ -1204,15 +1380,26 @@ class Database(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) - def simple_update(self, table, keyvalues, updatevalues, desc): - return self.runInteraction( + async def simple_update( + self, + table: str, + keyvalues: Dict[str, Any], + updatevalues: Dict[str, Any], + desc: str, + ) -> int: + return await self.runInteraction( desc, self.simple_update_txn, table, keyvalues, updatevalues ) @staticmethod - def simple_update_txn(txn, table, keyvalues, updatevalues): + def simple_update_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + updatevalues: Dict[str, Any], + ) -> int: if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) else: where = "" @@ -1226,32 +1413,34 @@ class Database(object): return txn.rowcount - def simple_update_one( - self, table, keyvalues, updatevalues, desc="simple_update_one" - ): + async def simple_update_one( + self, + table: str, + keyvalues: Dict[str, Any], + updatevalues: Dict[str, Any], + desc: str = "simple_update_one", + ) -> None: """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - updatevalues : dict giving column names and values to update - retcols : optional list of column names to return - - If present, retcols gives a list of column names on which to perform - a SELECT statement *before* performing the UPDATE statement. The values - of these will be returned in a dict. - - These are performed within the same transaction, allowing an atomic - get-and-set. This can be used to implement compare-and-set by putting - the update column in the 'keyvalues' dict as well. + table: string giving the table name + keyvalues: dict of column names and values to select the row with + updatevalues: dict giving column names and values to update + desc: description of the transaction, for logging and metrics """ - return self.runInteraction( + await self.runInteraction( desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) @classmethod - def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + def simple_update_one_txn( + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + updatevalues: Dict[str, Any], + ) -> None: rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) if rowcount == 0: @@ -1259,8 +1448,18 @@ class Database(object): if rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) + # Ideally we could use the overload decorator here to specify that the + # return type is only optional if allow_none is True, but this does not work + # when you call a static method from an instance. + # See https://github.com/python/mypy/issues/7781 @staticmethod - def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + def simple_select_one_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: bool = False, + ) -> Optional[Dict[str, Any]]: select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, @@ -1279,24 +1478,29 @@ class Database(object): return dict(zip(retcols, row)) - def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): + async def simple_delete_one( + self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" + ) -> None: """Executes a DELETE query on the named table, expecting to delete a single row. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with + table: string giving the table name + keyvalues: dict of column names and values to select the row with + desc: description of the transaction, for logging and metrics """ - return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @staticmethod - def simple_delete_one_txn(txn, table, keyvalues): + def simple_delete_one_txn( + txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any] + ) -> None: """Executes a DELETE query on the named table, expecting to delete a single row. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with + table: string giving the table name + keyvalues: dict of column names and values to select the row with """ sql = "DELETE FROM %s WHERE %s" % ( table, @@ -1309,11 +1513,38 @@ class Database(object): if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - def simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + async def simple_delete( + self, table: str, keyvalues: Dict[str, Any], desc: str + ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by the key-value pairs. + + Args: + table: string giving the table name + keyvalues: dict of column names and values to select the row with + desc: description of the transaction, for logging and metrics + + Returns: + The number of deleted rows. + """ + return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) @staticmethod - def simple_delete_txn(txn, table, keyvalues): + def simple_delete_txn( + txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any] + ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by the key-value pairs. + + Args: + table: string giving the table name + keyvalues: dict of column names and values to select the row with + + Returns: + The number of deleted rows. + """ sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k,) for k in keyvalues), @@ -1322,26 +1553,53 @@ class Database(object): txn.execute(sql, list(keyvalues.values())) return txn.rowcount - def simple_delete_many(self, table, column, iterable, keyvalues, desc): - return self.runInteraction( + async def simple_delete_many( + self, + table: str, + column: str, + iterable: Iterable[Any], + keyvalues: Dict[str, Any], + desc: str, + ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + keyvalues: dict of column names and values to select the rows with + desc: description of the transaction, for logging and metrics + + Returns: + Number rows deleted + """ + return await self.runInteraction( desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) @staticmethod - def simple_delete_many_txn(txn, table, column, iterable, keyvalues): + def simple_delete_many_txn( + txn: LoggingTransaction, + table: str, + column: str, + iterable: Iterable[Any], + keyvalues: Dict[str, Any], + ) -> int: """Executes a DELETE query on the named table. Filters rows by if value of `column` is in `iterable`. Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with + txn: Transaction object + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + keyvalues: dict of column names and values to select the rows with Returns: - int: Number rows deleted + Number rows deleted """ if not iterable: return 0 @@ -1351,7 +1609,7 @@ class Database(object): clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) clauses = [clause] - for key, value in iteritems(keyvalues): + for key, value in keyvalues.items(): clauses.append("%s = ?" % (key,)) values.append(value) @@ -1362,8 +1620,14 @@ class Database(object): return txn.rowcount def get_cache_dict( - self, db_conn, table, entity_column, stream_column, max_value, limit=100000 - ): + self, + db_conn: Connection, + table: str, + entity_column: str, + stream_column: str, + max_value: int, + limit: int = 100000, + ) -> Tuple[Dict[Any, int], int]: # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will # do the right thing to ensure it respects the max size of cache. @@ -1388,71 +1652,25 @@ class Database(object): txn.close() if cache: - min_val = min(itervalues(cache)) + min_val = min(cache.values()) else: min_val = max_value return cache, min_val - def simple_select_list_paginate( - self, - table, - orderby, - start, - limit, - retcols, - filters=None, - keyvalues=None, - order_direction="ASC", - desc="simple_select_list_paginate", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - table (str): the table name - filters (dict[str, T] | None): - column names and values to filter the rows with, or None to not - apply a WHERE ? LIKE ? clause. - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, - self.simple_select_list_paginate_txn, - table, - orderby, - start, - limit, - retcols, - filters=filters, - keyvalues=keyvalues, - order_direction=order_direction, - ) - @classmethod def simple_select_list_paginate_txn( cls, - txn, - table, - orderby, - start, - limit, - retcols, - filters=None, - keyvalues=None, - order_direction="ASC", - ): + txn: LoggingTransaction, + table: str, + orderby: str, + start: int, + limit: int, + retcols: Iterable[str], + filters: Optional[Dict[str, Any]] = None, + keyvalues: Optional[Dict[str, Any]] = None, + order_direction: str = "ASC", + ) -> List[Dict[str, Any]]: """ Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, @@ -1463,21 +1681,22 @@ class Database(object): using 'AND'. Args: - txn : Transaction object - table (str): the table name - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - filters (dict[str, T] | None): + txn: Transaction object + table: the table name + orderby: Column to order the results by. + start: Index to begin the query at. + limit: Number of results to return. + retcols: the names of the columns to return + filters: column names and values to filter the rows with, or None to not apply a WHERE ? LIKE ? clause. - keyvalues (dict[str, T] | None): + keyvalues: column names and values to select the rows with, or None to not apply a WHERE clause. - order_direction (str): Whether the results should be ordered "ASC" or "DESC". + order_direction: Whether the results should be ordered "ASC" or "DESC". + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + The result as a list of dictionaries. """ if order_direction not in ["ASC", "DESC"]: raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") @@ -1503,51 +1722,65 @@ class Database(object): return cls.cursor_to_dict(txn) - def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): + async def simple_search_list( + self, + table: str, + term: Optional[str], + col: str, + retcols: Iterable[str], + desc="simple_search_list", + ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Args: - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return + table: the table name + term: term for searching the table matched to a column. + col: column to query term should be matched to + retcols: the names of the columns to return + Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None + A list of dictionaries or None. """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_search_list_txn, table, term, col, retcols ) @classmethod - def simple_search_list_txn(cls, txn, table, term, col, retcols): + def simple_search_list_txn( + cls, + txn: LoggingTransaction, + table: str, + term: Optional[str], + col: str, + retcols: Iterable[str], + ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Args: - txn : Transaction object - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return + txn: Transaction object + table: the table name + term: term for searching the table matched to a column. + col: column to query term should be matched to + retcols: the names of the columns to return + Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None + None if no term is given, otherwise a list of dictionaries. """ if term: sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) termvalues = ["%%" + term + "%%"] txn.execute(sql, termvalues) else: - return 0 + return None return cls.cursor_to_dict(txn) def make_in_list_sql_clause( - database_engine, column: str, iterable: Iterable + database_engine: BaseDatabaseEngine, column: str, iterable: Iterable ) -> Tuple[str, list]: """Returns an SQL clause that checks the given column is in the iterable. diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py new file mode 100644 index 0000000000..985b12df91 --- /dev/null +++ b/synapse/storage/databases/__init__.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.events import PersistEventsStore +from synapse.storage.databases.state import StateGroupDataStore +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database + +logger = logging.getLogger(__name__) + + +class Databases: + """The various databases. + + These are low level interfaces to physical databases. + + Attributes: + main (DataStore) + """ + + def __init__(self, main_store_class, hs): + # Note we pass in the main store class here as workers use a different main + # store. + + self.databases = [] + main = None + state = None + persist_events = None + + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: + logger.info("[database config %r]: Checking database server", db_name) + engine.check_database(db_conn) + + logger.info( + "[database config %r]: Preparing for databases %r", + db_name, + database_config.databases, + ) + prepare_database( + db_conn, engine, hs.config, databases=database_config.databases, + ) + + database = DatabasePool(hs, database_config, engine) + + if "main" in database_config.databases: + logger.info( + "[database config %r]: Starting 'main' database", db_name + ) + + # Sanity check we don't try and configure the main store on + # multiple databases. + if main: + raise Exception("'main' data store already configured") + + main = main_store_class(database, db_conn, hs) + + # If we're on a process that can persist events also + # instantiate a `PersistEventsStore` + if hs.config.worker.writers.events == hs.get_instance_name(): + persist_events = PersistEventsStore(hs, database, main) + + if "state" in database_config.databases: + logger.info( + "[database config %r]: Starting 'state' database", db_name + ) + + # Sanity check we don't try and configure the state store on + # multiple databases. + if state: + raise Exception("'state' data store already configured") + + state = StateGroupDataStore(database, db_conn, hs) + + db_conn.commit() + + self.databases.append(database) + + logger.info("[database config %r]: prepared", db_name) + + # Closing the context manager doesn't close the connection. + # psycopg will close the connection when the object gets GCed, but *only* + # if the PID is the same as when the connection was opened [1], and + # it may not be if we fork in the meantime. + # + # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378 + + db_conn.close() + + # Sanity check that we have actually configured all the required stores. + if not main: + raise Exception("No 'main' database configured") + + if not state: + raise Exception("No 'state' database configured") + + # We use local variables here to ensure that the databases do not have + # optional types. + self.main = main + self.state = state + self.persist_events = persist_events diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/databases/main/__init__.py index 4b4763c701..2ae2fbd5d7 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,16 +18,18 @@ import calendar import logging import time +from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( IdGenerator, MultiWriterIdGenerator, StreamIdGenerator, ) +from synapse.types import get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore @@ -119,7 +121,7 @@ class DataStore( CacheInvalidationWorkerStore, ServerMetricsStore, ): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine @@ -128,7 +130,7 @@ class DataStore( db_conn, "presence_stream", "stream_id" ) self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_max_stream_id", "stream_id" + db_conn, "device_inbox", "stream_id" ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" @@ -174,7 +176,7 @@ class DataStore( self._presence_on_startup = self._get_active_presence(db_conn) - presence_cache_prefill, min_presence_val = self.db.get_cache_dict( + presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", @@ -188,7 +190,7 @@ class DataStore( ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( + device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", @@ -203,7 +205,7 @@ class DataStore( ) # The federation outbox and the local device inbox uses the same # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( + device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", @@ -229,7 +231,7 @@ class DataStore( ) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -243,7 +245,7 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( + _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", @@ -263,6 +265,9 @@ class DataStore( # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None @@ -282,7 +287,7 @@ class DataStore( txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) txn.close() for row in rows: @@ -290,14 +295,16 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - def count_daily_users(self): + async def count_daily_users(self) -> int: """ Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.db.runInteraction("count_daily_users", self._count_users, yesterday) + return await self.db_pool.runInteraction( + "count_daily_users", self._count_users, yesterday + ) - def count_monthly_users(self): + async def count_monthly_users(self) -> int: """ Counts the number of users who used this homeserver in the last 30 days. Note this method is intended for phonehome metrics only and is different @@ -305,7 +312,7 @@ class DataStore( amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -324,15 +331,15 @@ class DataStore( (count,) = txn.fetchone() return count - def count_r30_users(self): + async def count_r30_users(self) -> Dict[str, int]: """ Counts the number of 30 day retained users, defined as:- * Users who have created their accounts more than 30 days ago * Where last seen at most 30 days ago * Where account creation and last_seen are > 30 days apart - Returns counts globaly for a given user as well as breaking - by platform + Returns: + A mapping of counts globally as well as broken out by platform. """ def _count_r30_users(txn): @@ -405,7 +412,7 @@ class DataStore( return results - return self.db.runInteraction("count_r30_users", _count_r30_users) + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -415,7 +422,7 @@ class DataStore( today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 - def generate_user_daily_visits(self): + async def generate_user_daily_visits(self) -> None: """ Generates daily visit data for use in cohort/ retention analysis """ @@ -470,18 +477,17 @@ class DataStore( # frequently self._last_user_visit_update = now - return self.db.runInteraction( + await self.db_pool.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) - def get_users(self): + async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. - Args: Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries representing users. """ - return self.db.simple_select_list( + return await self.db_pool.simple_select_list( table="users", keyvalues={}, retcols=[ @@ -495,30 +501,40 @@ class DataStore( desc="get_users", ) - def get_users_paginate( - self, start, limit, name=None, guests=True, deactivated=False - ): + async def get_users_paginate( + self, + start: int, + limit: int, + user_id: Optional[str] = None, + name: Optional[str] = None, + guests: bool = True, + deactivated: bool = False, + ) -> Tuple[List[Dict[str, Any]], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. Args: - start (int): start number to begin the query from - limit (int): number of rows to retrieve - name (string): filter for user names - guests (bool): whether to in include guest users - deactivated (bool): whether to include deactivated users + start: start number to begin the query from + limit: number of rows to retrieve + user_id: search for user_id. ignored if name is not None + name: search for local part of user_id or display name + guests: whether to in include guest users + deactivated: whether to include deactivated users Returns: - defer.Deferred: resolves to list[dict[str, Any]], int + A tuple of a list of mappings from user to information and a count of total users. """ def get_users_paginate_txn(txn): filters = [] - args = [] + args = [self.hs.config.server_name] if name: + filters.append("(name LIKE ? OR displayname LIKE ?)") + args.extend(["@%" + name + "%:%", "%" + name + "%"]) + elif user_id: filters.append("name LIKE ?") - args.append("%" + name + "%") + args.extend(["%" + user_id + "%"]) if not guests: filters.append("is_guest = 0") @@ -528,37 +544,42 @@ class DataStore( where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause) - txn.execute(sql, args) - count = txn.fetchone()[0] - - args = [self.hs.config.server_name] + args + [limit, start] - sql = """ - SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url + sql_base = """ FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? {} - ORDER BY u.name LIMIT ? OFFSET ? """.format( where_clause ) + sql = "SELECT COUNT(*) as total_users " + sql_base + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = ( + "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url " + + sql_base + + " ORDER BY u.name LIMIT ? OFFSET ?" + ) + args += [limit, start] txn.execute(sql, args) - users = self.db.cursor_to_dict(txn) + users = self.db_pool.cursor_to_dict(txn) return users, count - return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn) + return await self.db_pool.runInteraction( + "get_users_paginate_txn", get_users_paginate_txn + ) - def search_users(self, term): + async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]: """Function to search users list for one or more users with the matched term. Args: - term (str): search term - col (str): column to query term should be matched to + term: search term + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries or None. """ - return self.db.simple_search_list( + return await self.db_pool.simple_search_list( table="users", term=term, col="name", @@ -571,21 +592,24 @@ def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig """Called before upgrading an existing database to check that it is broadly sane compared with the configuration. """ - domain = config.server_name + logger.info("Checking database for consistency with configuration...") - sql = database_engine.convert_param_style( - "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" - ) - pat = "%:" + domain - cur.execute(sql, (pat,)) - num_not_matching = cur.fetchall()[0][0] - if num_not_matching == 0: + # if there are any users in the database, check that the username matches our + # configured server name. + + cur.execute("SELECT name FROM users LIMIT 1") + rows = cur.fetchall() + if not rows: + return + + user_domain = get_domain_from_id(rows[0][0]) + if user_domain == config.server_name: return raise Exception( "Found users in database not native to %s!\n" - "You cannot changed a synapse server_name after it's been configured" - % (domain,) + "You cannot change a synapse server_name after it's been configured" + % (config.server_name,) ) diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/databases/main/account_data.py index b58f04d00d..4436b1a83d 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -16,16 +16,14 @@ import abc import logging -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple -from canonicaljson import json - -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.types import JsonDict +from synapse.util import json_encoder +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -40,7 +38,7 @@ class AccountDataWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max @@ -58,18 +56,20 @@ class AccountDataWorkerStore(SQLBaseStore): raise NotImplementedError() @cached() - def get_account_data_for_user(self, user_id): + async def get_account_data_for_user( + self, user_id: str + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a user. Args: - user_id(str): The user to get the account_data for. + user_id: The user to get the account_data for. Returns: - A deferred pair of a dict of global account_data and a dict - mapping from room_id string to per room account_data dicts. + A 2-tuple of a dict of global account_data and a dict mapping from + room_id string to per room account_data dicts. """ def get_account_data_for_user_txn(txn): - rows = self.db.simple_select_list_txn( + rows = self.db_pool.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -77,10 +77,10 @@ class AccountDataWorkerStore(SQLBaseStore): ) global_account_data = { - row["account_data_type"]: json.loads(row["content"]) for row in rows + row["account_data_type"]: db_to_json(row["content"]) for row in rows } - rows = self.db.simple_select_list_txn( + rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -90,21 +90,23 @@ class AccountDataWorkerStore(SQLBaseStore): by_room = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) - room_data[row["account_data_type"]] = json.loads(row["content"]) + room_data[row["account_data_type"]] = db_to_json(row["content"]) return global_account_data, by_room - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, max_entries=5000) - def get_global_account_data_by_type_for_user(self, data_type, user_id): + @cached(num_args=2, max_entries=5000) + async def get_global_account_data_by_type_for_user( + self, data_type: str, user_id: str + ) -> Optional[JsonDict]: """ Returns: - Deferred: A dict + The account data. """ - result = yield self.db.simple_select_one_onecol( + result = await self.db_pool.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -113,23 +115,25 @@ class AccountDataWorkerStore(SQLBaseStore): ) if result: - return json.loads(result) + return db_to_json(result) else: return None @cached(num_args=2) - def get_account_data_for_room(self, user_id, room_id): + async def get_account_data_for_room( + self, user_id: str, room_id: str + ) -> Dict[str, JsonDict]: """Get all the client account_data for a user for a room. Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. + user_id: The user to get the account_data for. + room_id: The room to get the account_data for. Returns: - A deferred dict of the room account_data + A dict of the room account_data """ def get_account_data_for_room_txn(txn): - rows = self.db.simple_select_list_txn( + rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -137,28 +141,29 @@ class AccountDataWorkerStore(SQLBaseStore): ) return { - row["account_data_type"]: json.loads(row["content"]) for row in rows + row["account_data_type"]: db_to_json(row["content"]) for row in rows } - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @cached(num_args=3, max_entries=5000) - def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + async def get_account_data_for_room_and_type( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[JsonDict]: """Get the client account_data of given type for a user for a room. Args: - user_id(str): The user to get the account_data for. - room_id(str): The room to get the account_data for. - account_data_type (str): The account data type to get. + user_id: The user to get the account_data for. + room_id: The room to get the account_data for. + account_data_type: The account data type to get. Returns: - A deferred of the room account_data for that type, or None if - there isn't any set. + The room account_data for that type, or None if there isn't any set. """ def get_account_data_for_room_and_type_txn(txn): - content_json = self.db.simple_select_one_onecol_txn( + content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -170,9 +175,9 @@ class AccountDataWorkerStore(SQLBaseStore): allow_none=True, ) - return json.loads(content_json) if content_json else None + return db_to_json(content_json) if content_json else None - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) @@ -202,7 +207,7 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn ) @@ -232,16 +237,18 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id): + async def get_updated_account_data_for_user( + self, user_id: str, stream_id: int + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a that's changed for a user Args: - user_id(str): The user to get the account_data for. - stream_id(int): The point in the stream since which to get updates + user_id: The user to get the account_data for. + stream_id: The point in the stream since which to get updates Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. @@ -255,7 +262,7 @@ class AccountDataWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) - global_account_data = {row[0]: json.loads(row[1]) for row in txn} + global_account_data = {row[0]: db_to_json(row[1]) for row in txn} sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" @@ -267,7 +274,7 @@ class AccountDataWorkerStore(SQLBaseStore): account_data_by_room = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) - room_account_data[row[1]] = json.loads(row[2]) + room_account_data[row[1]] = db_to_json(row[2]) return global_account_data, account_data_by_room @@ -275,15 +282,17 @@ class AccountDataWorkerStore(SQLBaseStore): user_id, int(stream_id) ) if not changed: - return defer.succeed(({}, {})) + return ({}, {}) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) - def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): - ignored_account_data = yield self.get_global_account_data_by_type_for_user( + @cached(num_args=2, cache_context=True, max_entries=5000) + async def is_ignored_by( + self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext + ) -> bool: + ignored_account_data = await self.get_global_account_data_by_type_for_user( "m.ignored_user_list", ignorer_user_id, on_invalidate=cache_context.invalidate, @@ -295,7 +304,7 @@ class AccountDataWorkerStore(SQLBaseStore): class AccountDataStore(AccountDataWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", @@ -308,32 +317,35 @@ class AccountDataStore(AccountDataWorkerStore): super(AccountDataStore, self).__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self): + def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream Returns: - A deferred int. + The maximum stream ID. """ return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + async def add_account_data_to_room( + self, user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self.db.simple_upsert( + await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -351,7 +363,7 @@ class AccountDataStore(AccountDataWorkerStore): # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -360,26 +372,28 @@ class AccountDataStore(AccountDataWorkerStore): (user_id, room_id, account_data_type), content ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_for_user(self, user_id, account_data_type, content): + async def add_account_data_for_user( + self, user_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self.db.simple_upsert( + await self.db_pool.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -397,7 +411,7 @@ class AccountDataStore(AccountDataWorkerStore): # Note: This is only here for backwards compat to allow admins to # roll back to a previous Synapse version. Next time we update the # database version we can remove this table. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -405,14 +419,13 @@ class AccountDataStore(AccountDataWorkerStore): (account_data_type, user_id) ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_max_stream_id(self, next_id): + async def _update_max_stream_id(self, next_id: int) -> None: """Update the max stream_id Args: - next_id(int): The the revision to advance to. + next_id: The the revision to advance to. """ # Note: This is only here for backwards compat to allow admins to @@ -427,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore): ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.db.runInteraction("update_account_data_max_stream_id", _update) + await self.db_pool.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/databases/main/appservice.py index 7a1fe8cdd2..454c0bc50c 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -16,15 +16,12 @@ import logging import re -from canonicaljson import json - -from twisted.internet import defer - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util import json_encoder logger = logging.getLogger(__name__) @@ -49,7 +46,7 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) @@ -124,17 +121,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - @defer.inlineCallbacks - def get_appservices_by_state(self, state): + async def get_appservices_by_state(self, state): """Get a list of application services based on their state. Args: state(ApplicationServiceState): The state to filter on. Returns: - A Deferred which resolves to a list of ApplicationServices, which - may be empty. + A list of ApplicationServices, which may be empty. """ - results = yield self.db.simple_select_list( + results = await self.db_pool.simple_select_list( "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -147,16 +142,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - @defer.inlineCallbacks - def get_appservice_state(self, service): + async def get_appservice_state(self, service): """Get the application service state. Args: service(ApplicationService): The service whose state to set. Returns: - A Deferred which resolves to ApplicationServiceState. + An ApplicationServiceState. """ - result = yield self.db.simple_select_one( + result = await self.db_pool.simple_select_one( "application_services_state", {"as_id": service.id}, ["state"], @@ -167,20 +161,18 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - def set_appservice_state(self, service, state): + async def set_appservice_state(self, service, state) -> None: """Set the application service state. Args: service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. - Returns: - A Deferred which resolves when the state was set successfully. """ - return self.db.simple_upsert( + await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) - def create_appservice_txn(self, service, events): + async def create_appservice_txn(self, service, events): """Atomically creates a new transaction for this application service with the given list of events. @@ -209,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore( new_txn_id = max(highest_txn_id, last_txn_id) + 1 # Insert new txn into txn table - event_ids = json.dumps([e.event_id for e in events]) + event_ids = json_encoder.encode([e.event_id for e in events]) txn.execute( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "VALUES(?,?,?)", @@ -217,18 +209,17 @@ class ApplicationServiceTransactionWorkerStore( ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) + return await self.db_pool.runInteraction( + "create_appservice_txn", _create_appservice_txn + ) - def complete_appservice_txn(self, txn_id, service): + async def complete_appservice_txn(self, txn_id, service) -> None: """Completes an application service transaction. Args: txn_id(str): The transaction ID being completed. service(ApplicationService): The application service which was sent this transaction. - Returns: - A Deferred which resolves if this transaction was stored - successfully. """ txn_id = int(txn_id) @@ -250,7 +241,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Set current txn_id for AS to 'txn_id' - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, "application_services_state", {"as_id": service.id}, @@ -258,26 +249,24 @@ class ApplicationServiceTransactionWorkerStore( ) # Delete txn - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, "application_services_txns", {"txn_id": txn_id, "as_id": service.id}, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "complete_appservice_txn", _complete_appservice_txn ) - @defer.inlineCallbacks - def get_oldest_unsent_txn(self, service): + async def get_oldest_unsent_txn(self, service): """Get the oldest transaction which has not been sent for this service. Args: service(ApplicationService): The app service to get the oldest txn. Returns: - A Deferred which resolves to an AppServiceTransaction or - None. + An AppServiceTransaction or None. """ def _get_oldest_unsent_txn(txn): @@ -288,7 +277,7 @@ class ApplicationServiceTransactionWorkerStore( " ORDER BY txn_id ASC LIMIT 1", (service.id,), ) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return None @@ -296,16 +285,16 @@ class ApplicationServiceTransactionWorkerStore( return entry - entry = yield self.db.runInteraction( + entry = await self.db_pool.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) if not entry: return None - event_ids = json.loads(entry["event_ids"]) + event_ids = db_to_json(entry["event_ids"]) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) @@ -320,18 +309,17 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - def set_appservice_last_pos(self, pos): + async def set_appservice_last_pos(self, pos) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) - @defer.inlineCallbacks - def get_new_events_for_appservice(self, current_id, limit): + async def get_new_events_for_appservice(self, current_id, limit): """Get all new evnets""" def get_new_events_for_appservice_txn(txn): @@ -355,11 +343,11 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/databases/main/cache.py index eac5a4e55b..1e7637a6f5 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -16,15 +16,17 @@ import itertools import logging -from typing import Any, Iterable, Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams import BackfillStream, CachesStream from synapse.replication.tcp.streams.events import ( + EventsStream, EventsStreamCurrentStateRow, EventsStreamEventRow, ) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -37,20 +39,37 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() async def get_all_updated_caches( self, instance_name: str, last_id: int, current_id: int, limit: int - ): - """Fetches cache invalidation rows between the two given IDs written - by the given instance. Returns at most `limit` rows. + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for caches replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: - return [] + return [], current_id, False def get_all_updated_caches_txn(txn): # We purposefully don't bound by the current token, as we want to @@ -64,17 +83,24 @@ class CacheInvalidationWorkerStore(SQLBaseStore): LIMIT ? """ txn.execute(sql, (last_id, instance_name, limit)) - return txn.fetchall() + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_updated_caches", get_all_updated_caches_txn ) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "events": + if stream_name == EventsStream.NAME: for row in rows: self._process_event_stream_row(token, row) - elif stream_name == "backfill": + elif stream_name == BackfillStream.NAME: for row in rows: self._invalidate_caches_for_event( -token, @@ -86,7 +112,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): row.relates_to, backfilled=True, ) - elif stream_name == "caches": + elif stream_name == CachesStream.NAME: if self._cache_id_gen: self._cache_id_gen.advance(instance_name, token) @@ -176,7 +202,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): return cache_func.invalidate(keys) - await self.db.runInteraction( + await self.db_pool.runInteraction( "invalidate_cache_and_stream", self._send_invalidation_to_replication, cache_func.__name__, @@ -261,7 +287,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if keys is not None: keys = list(keys) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="cache_invalidation_stream_by_instance", values={ @@ -273,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): }, ) - def get_cache_stream_token(self, instance_name): + def get_cache_stream_token_for_writer(self, instance_name: str) -> int: if self._cache_id_gen: - return self._cache_id_gen.get_current_token(instance_name) + return self._cache_id_gen.get_current_token_for_writer(instance_name) else: return 0 diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 2d48261724..f211ddbaf8 100644 --- a/synapse/storage/data_stores/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -16,15 +16,13 @@ import logging from typing import TYPE_CHECKING -from twisted.internet import defer - from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore -from synapse.storage.data_stores.main.events import encode_json -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.events import encode_json +from synapse.storage.databases.main.events_worker import EventsWorkerStore if TYPE_CHECKING: from synapse.server import HomeServer @@ -34,7 +32,7 @@ logger = logging.getLogger(__name__) class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): - def __init__(self, database: Database, db_conn, hs: "HomeServer"): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) def _censor_redactions(): @@ -56,7 +54,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return if not ( - await self.db.updates.has_completed_background_update( + await self.db_pool.updates.has_completed_background_update( "redactions_have_censored_ts_idx" ) ): @@ -85,7 +83,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase LIMIT ? """ - rows = await self.db.execute( + rows = await self.db_pool.execute( "_censor_redactions_fetch", None, sql, before_ts, 100 ) @@ -123,14 +121,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase if pruned_json: self._censor_event_txn(txn, event_id, pruned_json) - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="redactions", keyvalues={"event_id": redaction_id}, updatevalues={"have_censored": True}, ) - await self.db.runInteraction("_update_censor_txn", _update_censor_txn) + await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn) def _censor_event_txn(self, txn, event_id, pruned_json): """Censor an event by replacing its JSON in the event_json table with the @@ -141,24 +139,23 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase event_id (str): The ID of the event to censor. pruned_json (str): The pruned JSON """ - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="event_json", keyvalues={"event_id": event_id}, updatevalues={"json": pruned_json}, ) - @defer.inlineCallbacks - def expire_event(self, event_id): + async def expire_event(self, event_id: str) -> None: """Retrieve and expire an event that has expired, and delete its associated expiry timestamp. If the event can't be retrieved, delete its associated timestamp so we don't try to expire it again in the future. Args: - event_id (str): The ID of the event to delete. + event_id: The ID of the event to delete. """ # Try to retrieve the event's content from the database or the event cache. - event = yield self.get_event(event_id) + event = await self.get_event(event_id) def delete_expired_event_txn(txn): # Delete the expiry timestamp associated with this event from the database. @@ -193,7 +190,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase txn, "_get_event_cache", (event.event_id,) ) - yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) + await self.db_pool.runInteraction( + "delete_expired_event", delete_expired_event_txn + ) def _delete_event_expiry_txn(self, txn, event_id): """Delete the expiry timestamp associated with an event ID without deleting the @@ -203,6 +202,6 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase txn (LoggingTransaction): The transaction to use to perform the deletion. event_id (str): The event ID to delete the associated expiry timestamp of. """ - return self.db.simple_delete_txn( + return self.db_pool.simple_delete_txn( txn=txn, table="event_expiry", keyvalues={"event_id": event_id} ) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 71f8d43a76..c2fc847fbc 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -14,14 +14,11 @@ # limitations under the License. import logging - -from six import iteritems - -from twisted.internet import defer +from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database, make_tuple_comparison_clause +from synapse.storage.database import DatabasePool, make_tuple_comparison_clause from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -33,40 +30,40 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "user_ips_device_index", index_name="user_ips_device_id", table="user_ips", columns=["user_id", "device_id", "last_seen"], ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "user_ips_last_seen_index", index_name="user_ips_last_seen", table="user_ips", columns=["user_id", "last_seen"], ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "user_ips_last_seen_only_index", index_name="user_ips_last_seen_only", table="user_ips", columns=["last_seen"], ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "user_ips_analyze", self._analyze_user_ip ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "user_ips_remove_dupes", self._remove_user_ip_dupes ) # Register a unique index - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "user_ips_device_unique_index", index_name="user_ips_user_token_ip_unique_index", table="user_ips", @@ -75,28 +72,28 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): ) # Drop the old non-unique index - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique ) # Update the last seen info in devices. - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "devices_last_seen", self._devices_last_seen_update ) - @defer.inlineCallbacks - def _remove_user_ip_nonunique(self, progress, batch_size): + async def _remove_user_ip_nonunique(self, progress, batch_size): def f(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() - yield self.db.runWithConnection(f) - yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( + "user_ips_drop_nonunique_index" + ) return 1 - @defer.inlineCallbacks - def _analyze_user_ip(self, progress, batch_size): + async def _analyze_user_ip(self, progress, batch_size): # Background update to analyze user_ips table before we run the # deduplication background update. The table may not have been analyzed # for ages due to the table locks. @@ -106,14 +103,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) + await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) - yield self.db.updates._end_background_update("user_ips_analyze") + await self.db_pool.updates._end_background_update("user_ips_analyze") return 1 - @defer.inlineCallbacks - def _remove_user_ip_dupes(self, progress, batch_size): + async def _remove_user_ip_dupes(self, progress, batch_size): # This works function works by scanning the user_ips table in batches # based on `last_seen`. For each row in a batch it searches the rest of # the table to see if there are any duplicates, if there are then they @@ -140,7 +136,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return None # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.db.runInteraction( + end_last_seen = await self.db_pool.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) @@ -271,19 +267,18 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): (user_id, access_token, ip, device_id, user_agent, last_seen), ) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.db.runInteraction("user_ips_dups_remove", remove) + await self.db_pool.runInteraction("user_ips_dups_remove", remove) if last: - yield self.db.updates._end_background_update("user_ips_remove_dupes") + await self.db_pool.updates._end_background_update("user_ips_remove_dupes") return batch_size - @defer.inlineCallbacks - def _devices_last_seen_update(self, progress, batch_size): + async def _devices_last_seen_update(self, progress, batch_size): """Background update to insert last seen info into devices table """ @@ -338,7 +333,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): txn.execute_batch(sql, rows) _, _, _, user_id, device_id = rows[-1] - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "devices_last_seen", {"last_user_id": user_id, "last_device_id": device_id}, @@ -346,18 +341,18 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return len(rows) - updated = yield self.db.runInteraction( + updated = await self.db_pool.runInteraction( "_devices_last_seen_update", _devices_last_seen_update_txn ) if not updated: - yield self.db.updates._end_background_update("devices_last_seen") + await self.db_pool.updates._end_background_update("devices_last_seen") return updated class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 @@ -380,8 +375,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - @defer.inlineCallbacks - def insert_client_ip( + async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): if not now: @@ -392,7 +386,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - yield self.populate_monthly_active_users(user_id) + await self.populate_monthly_active_users(user_id) # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return @@ -402,30 +396,30 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): + async def _update_client_ips_batch(self) -> None: # If the DB pool has already terminated, don't try updating - if not self.db.is_running(): + if not self.db_pool.is_running(): return to_update = self._batch_row_update self._batch_row_update = {} - return self.db.runInteraction( + await self.db_pool.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) def _update_client_ips_batch_txn(self, txn, to_update): - if "user_ips" in self.db._unsafe_to_upsert_tables or ( + if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): self.database_engine.lock_table(txn, "user_ips") - for entry in iteritems(to_update): + for entry in to_update.items(): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -447,7 +441,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # this is always an update rather than an upsert: the row should # already exist, and if it doesn't, that may be because it has been # deleted, and we don't want to re-create it. - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -461,25 +455,25 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Failed to upsert, log and continue logger.error("Failed to insert client IP %r: %r", entry, e) - @defer.inlineCallbacks - def get_last_client_ip_by_device(self, user_id, device_id): + async def get_last_client_ip_by_device( + self, user_id: str, device_id: Optional[str] + ) -> Dict[Tuple[str, str], dict]: """For each device_id listed, give the user_ip it was last seen on Args: - user_id (str) - device_id (str): If None fetches all devices for the user + user_id: The user to fetch devices for. + device_id: If None fetches all devices for the user Returns: - defer.Deferred: resolves to a dict, where the keys - are (user_id, device_id) tuples. The values are also dicts, with - keys giving the column names + A dictionary mapping a tuple of (user_id, device_id) to dicts, with + keys giving the column names from the devices table. """ keyvalues = {"user_id": user_id} if device_id is not None: keyvalues["device_id"] = device_id - res = yield self.db.simple_select_list( + res = await self.db_pool.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -501,8 +495,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): } return ret - @defer.inlineCallbacks - def get_user_ip_and_agents(self, user): + async def get_user_ip_and_agents(self, user): user_id = user.to_string() results = {} @@ -512,7 +505,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], @@ -530,7 +523,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "user_agent": user_agent, "last_seen": last_seen, } - for (access_token, ip), (user_agent, last_seen) in iteritems(results) + for (access_token, ip), (user_agent, last_seen) in results.items() ] @wrap_as_background_process("prune_old_user_ips") @@ -542,7 +535,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Nothing to do return - if not await self.db.updates.has_completed_background_update( + if not await self.db_pool.updates.has_completed_background_update( "devices_last_seen" ): # Only start pruning if we have finished populating the devices @@ -575,4 +568,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _prune_old_user_ips_txn(txn): txn.execute(sql, (timestamp,)) - await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) + await self.db_pool.runInteraction( + "_prune_old_user_ips", _prune_old_user_ips_txn + ) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 9a1178fb39..0044433110 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,14 +14,12 @@ # limitations under the License. import logging - -from canonicaljson import json - -from twisted.internet import defer +from typing import List, Tuple from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() - def get_new_messages_for_device( - self, user_id, device_id, last_stream_id, current_stream_id, limit=100 - ): + async def get_new_messages_for_device( + self, + user_id: str, + device_id: str, + last_stream_id: int, + current_stream_id: int, + limit: int = 100, + ) -> Tuple[List[dict], int]: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - current_stream_id(int): The current position of the to device + user_id: The recipient user_id. + device_id: The recipient device_id. + last_stream_id: The last stream ID checked. + current_stream_id: The current position of the to device message stream. + limit: The maximum number of messages to retrieve. + Returns: - Deferred ([dict], int): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. """ has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_stream_id ) if not has_changed: - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) def get_new_messages_for_device_txn(txn): sql = ( @@ -64,25 +69,27 @@ class DeviceInboxWorkerStore(SQLBaseStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(json.loads(row[1])) + messages.append(db_to_json(row[1])) if len(messages) < limit: stream_pos = current_stream_id return messages, stream_pos - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @trace - @defer.inlineCallbacks - def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + async def delete_messages_for_device( + self, user_id: str, device_id: str, up_to_stream_id: int + ) -> int: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - up_to_stream_id(int): Where to delete messages up to. + user_id: The recipient user_id. + device_id: The recipient device_id. + up_to_stream_id: Where to delete messages up to. + Returns: - A deferred that resolves to the number of messages deleted. + The number of messages deleted. """ # If we have cached the last stream id we've deleted up to, we can # check if there is likely to be anything that needs deleting @@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.db.runInteraction( + count = await self.db_pool.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): return count @trace - def get_new_device_msgs_for_remote( + async def get_new_device_msgs_for_remote( self, destination, last_stream_id, current_stream_id, limit - ): + ) -> Tuple[List[dict], int]: """ Args: destination(str): The name of the remote server. @@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): current_stream_id(int|long): The current position of the device message stream. Returns: - Deferred ([dict], int|long): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. """ set_tag("destination", destination) @@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. - return defer.succeed(([], last_stream_id)) + return ([], last_stream_id) @trace def get_new_messages_for_remote_destination_txn(txn): @@ -172,27 +178,27 @@ class DeviceInboxWorkerStore(SQLBaseStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(json.loads(row[1])) + messages.append(db_to_json(row[1])) if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id return messages, stream_pos - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @trace - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + async def delete_device_msgs_for_remote( + self, destination: str, up_to_stream_id: int + ) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. + destination: The destination server_name + up_to_stream_id: Where to delete messages up to. """ def delete_messages_for_remote_destination_txn(txn): @@ -203,35 +209,50 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.db.runInteraction( + await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ + async def get_all_new_device_messages( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for to device replication stream. + Args: - last_pos(int): - current_pos(int): - limit(int): + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + Returns: - A deferred list of rows from the device inbox + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ - if last_pos == current_pos: - return defer.succeed([]) + + if last_id == current_id: + return [], current_id, False def get_all_new_device_messages_txn(txn): # We limit like this as we might have multiple rows per stream_id, and # we want to make sure we always get all entries for any stream_id # we return. - upper_pos = min(current_pos, last_pos + limit) + upper_pos = min(current_id, last_id + limit) sql = ( "SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY user_id" ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() + txn.execute(sql, (last_id, upper_pos)) + updates = [(row[0], row[1:]) for row in txn] sql = ( "SELECT max(stream_id), destination" @@ -239,15 +260,21 @@ class DeviceInboxWorkerStore(SQLBaseStore): " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY destination" ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) + txn.execute(sql, (last_id, upper_pos)) + updates.extend((row[0], row[1:]) for row in txn) # Order by ascending stream ordering - rows.sort() + updates.sort() + + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - return rows + return updates, upto_token, limited - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn ) @@ -255,30 +282,29 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "device_inbox_stream_index", index_name="device_inbox_stream_id_user_id", table="device_inbox", columns=["stream_id", "user_id"], ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) - @defer.inlineCallbacks - def _background_drop_index_device_inbox(self, progress, batch_size): + async def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() - yield self.db.runWithConnection(reindex_txn) + await self.db_pool.runWithConnection(reindex_txn) - yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 @@ -286,7 +312,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(DeviceInboxStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been @@ -299,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ) @trace - @defer.inlineCallbacks - def add_messages_to_device_inbox( - self, local_messages_by_user_then_device, remote_messages_by_destination - ): + async def add_messages_to_device_inbox( + self, + local_messages_by_user_then_device: dict, + remote_messages_by_destination: dict, + ) -> int: """Used to send messages from this server. Args: - sender_user_id(str): The ID of the user sending these messages. - local_messages_by_user_and_device(dict): + local_messages_by_user_and_device: Dictionary of user_id to device_id to message. - remote_messages_by_destination(dict): + remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. + Returns: - A deferred stream_id that resolves when the messages have been - inserted. + The new stream_id. """ def add_messages_txn(txn, now_ms, stream_id): @@ -332,13 +358,13 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ) rows = [] for destination, edu in remote_messages_by_destination.items(): - edu_json = json.dumps(edu) + edu_json = json_encoder.encode(edu) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -350,15 +376,14 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) return self._device_inbox_id_gen.get_current_token() - @defer.inlineCallbacks - def add_messages_from_remote_to_device_inbox( - self, origin, message_id, local_messages_by_user_then_device - ): + async def add_messages_from_remote_to_device_inbox( + self, origin: str, message_id: str, local_messages_by_user_then_device: dict + ) -> int: def add_messages_txn(txn, now_ms, stream_id): # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self.db.simple_select_one_txn( + already_inserted = self.db_pool.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -370,7 +395,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Add an entry for this message_id so that we know we've processed # it. - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="device_federation_inbox", values={ @@ -386,9 +411,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, @@ -402,9 +427,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): - sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" - txn.execute(sql, (stream_id, stream_id)) - local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} @@ -413,7 +435,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Handle wildcard device_ids. sql = "SELECT device_id FROM devices WHERE user_id = ?" txn.execute(sql, (user_id,)) - message_json = json.dumps(messages_by_device["*"]) + message_json = json_encoder.encode(messages_by_device["*"]) for row in txn: # Add the message for all devices for this user on this # server. @@ -435,7 +457,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Only insert into the local inbox if the device exists on # this server device = row[0] - message_json = json.dumps(messages_by_device[device]) + message_json = json_encoder.encode(messages_by_device[device]) messages_json_for_user[device] = message_json if messages_json_for_user: diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/databases/main/devices.py index fb9f798e29..add4e3ea0e 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -14,14 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging -from typing import List, Optional, Set, Tuple - -from six import iteritems - -from canonicaljson import json - -from twisted.internet import defer +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -33,17 +28,13 @@ from synapse.logging.opentracing import ( from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( - Database, + DatabasePool, LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.types import Collection, get_verify_key_from_cross_signing_key -from synapse.util.caches.descriptors import ( - Cache, - cached, - cachedInlineCallbacks, - cachedList, -) +from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key +from synapse.util import json_encoder +from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -57,38 +48,36 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id, device_id): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve + user_id: The ID of the user which owns the device + device_id: The ID of the device to retrieve Returns: - defer.Deferred for a dict containing the device information + A dict containing the device information Raises: StoreError: if the device is not found """ - return self.db.simple_select_one( + return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", ) - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. Args: - user_id (str): + user_id: Returns: - defer.Deferred: resolves to a dict from device_id to a dict - containing "device_id", "user_id" and "display_name" for each - device. + A mapping from device_id to a dict containing "device_id", "user_id" + and "display_name" for each device. """ - devices = yield self.db.simple_select_list( + devices = await self.db_pool.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -98,21 +87,22 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} @trace - @defer.inlineCallbacks - def get_device_updates_by_remote(self, destination, from_stream_id, limit): + async def get_device_updates_by_remote( + self, destination: str, from_stream_id: int, limit: int + ) -> Tuple[int, List[Tuple[str, dict]]]: """Get a stream of device updates to send to the given remote server. Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - limit (int): Maximum number of device updates to return + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + limit: Maximum number of device updates to return + Returns: - Deferred[tuple[int, list[tuple[string,dict]]]]: - current stream id (ie, the stream id of the last update included in the - response), and the list of updates, where each update is a pair of EDU - type and EDU contents + A mapping from the current stream id (ie, the stream id of the last + update included in the response), and the list of updates, where + each update is a pair of EDU type and EDU contents. """ - now_stream_id = self._device_list_id_gen.get_current_token() + now_stream_id = self.get_device_stream_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -120,7 +110,7 @@ class DeviceWorkerStore(SQLBaseStore): if not has_changed: return now_stream_id, [] - updates = yield self.db.runInteraction( + updates = await self.db_pool.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, @@ -139,7 +129,7 @@ class DeviceWorkerStore(SQLBaseStore): master_key_by_user = {} self_signing_key_by_user = {} for user in users: - cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + cross_signing_key = await self.get_e2e_cross_signing_key(user, "master") if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key @@ -152,7 +142,7 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - cross_signing_key = yield self.get_e2e_cross_signing_key( + cross_signing_key = await self.get_e2e_cross_signing_key( user, "self_signing" ) if cross_signing_key: @@ -203,12 +193,12 @@ class DeviceWorkerStore(SQLBaseStore): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) - results = yield self._get_device_update_edus_by_remote( + results = await self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) # add the updated cross-signing keys to the results list - for user_id, result in iteritems(cross_signing_keys_by_user): + for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec results.append(("org.matrix.signing_key_update", result)) @@ -216,16 +206,21 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, results def _get_device_updates_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id, limit + self, + txn: LoggingTransaction, + destination: str, + from_stream_id: int, + now_stream_id: int, + limit: int, ): """Return device update information for a given remote destination Args: - txn (LoggingTransaction): The transaction to execute - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - now_stream_id (int): The maximum stream_id to filter updates by, inclusive - limit (int): Maximum number of device updates to return + txn: The transaction to execute + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + now_stream_id: The maximum stream_id to filter updates by, inclusive + limit: Maximum number of device updates to return Returns: List: List of device updates @@ -241,25 +236,26 @@ class DeviceWorkerStore(SQLBaseStore): return list(txn) - @defer.inlineCallbacks - def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): + async def _get_device_update_edus_by_remote( + self, + destination: str, + from_stream_id: int, + query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]], + ) -> List[Tuple[str, dict]]: """Returns a list of device update EDUs as well as E2EE keys Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping - user_id/device_id to update stream_id and the relevent json-encoded + user_id/device_id to update stream_id and the relevant json-encoded opentracing context Returns: - List[Dict]: List of objects representing an device update EDU - + List of objects representing an device update EDU """ devices = ( - yield self.db.runInteraction( - "_get_e2e_device_keys_txn", - self._get_e2e_device_keys_txn, + await self.get_e2e_device_keys_and_signatures( query_map.keys(), include_all_devices=True, include_deleted_devices=True, @@ -269,10 +265,10 @@ class DeviceWorkerStore(SQLBaseStore): ) results = [] - for user_id, user_devices in iteritems(devices): + for user_id, user_devices in devices.items(): # The prev_id for the first row is always the last row before # `from_stream_id` - prev_id = yield self._get_last_device_update_for_remote_user( + prev_id = await self._get_last_device_update_for_remote_user( destination, user_id, from_stream_id ) @@ -295,17 +291,11 @@ class DeviceWorkerStore(SQLBaseStore): prev_id = stream_id if device is not None: - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) + keys = device.keys + if keys: + result["keys"] = keys - device_display_name = device.get("device_display_name", None) + device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name else: @@ -315,9 +305,9 @@ class DeviceWorkerStore(SQLBaseStore): return results - def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id - ): + async def _get_last_device_update_for_remote_user( + self, destination: str, user_id: str, from_stream_id: int + ) -> int: def f(txn): prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id @@ -328,19 +318,25 @@ class DeviceWorkerStore(SQLBaseStore): rows = txn.fetchall() return rows[0][0] - return self.db.runInteraction("get_last_device_update_for_remote_user", f) + return await self.db_pool.runInteraction( + "get_last_device_update_for_remote_user", f + ) - def mark_as_sent_devices_by_remote(self, destination, stream_id): + async def mark_as_sent_devices_by_remote( + self, destination: str, stream_id: int + ) -> None: """Mark that updates have successfully been sent to the destination. """ - return self.db.runInteraction( + await self.db_pool.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, stream_id, ) - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + def _mark_as_sent_devices_by_remote_txn( + self, txn: LoggingTransaction, destination: str, stream_id: int + ) -> None: # We update the device_lists_outbound_last_success with the successfully # poked users. sql = """ @@ -352,7 +348,7 @@ class DeviceWorkerStore(SQLBaseStore): txn.execute(sql, (destination, stream_id)) rows = txn.fetchall() - self.db.simple_upsert_many_txn( + self.db_pool.simple_upsert_many_txn( txn=txn, table="device_lists_outbound_last_success", key_names=("destination", "user_id"), @@ -368,17 +364,21 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, stream_id)) - @defer.inlineCallbacks - def add_user_signature_change_to_streams(self, from_user_id, user_ids): + async def add_user_signature_change_to_streams( + self, from_user_id: str, user_ids: List[str] + ) -> int: """Persist that a user has made new signatures Args: - from_user_id (str): the user who made the signatures - user_ids (list[str]): the users who were signed + from_user_id: the user who made the signatures + user_ids: the users who were signed + + Returns: + THe new stream ID. """ - with self._device_list_id_gen.get_next() as stream_id: - yield self.db.runInteraction( + with await self._device_list_id_gen.get_next() as stream_id: + await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, from_user_id, @@ -387,45 +387,54 @@ class DeviceWorkerStore(SQLBaseStore): ) return stream_id - def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): + def _add_user_signature_change_txn( + self, + txn: LoggingTransaction, + from_user_id: str, + user_ids: List[str], + stream_id: int, + ) -> None: txn.call_after( self._user_signature_stream_cache.entity_has_changed, from_user_id, stream_id, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "user_signature_stream", values={ "stream_id": stream_id, "from_user_id": from_user_id, - "user_ids": json.dumps(user_ids), + "user_ids": json_encoder.encode(user_ids), }, ) - def get_device_stream_token(self): - return self._device_list_id_gen.get_current_token() + @abc.abstractmethod + def get_device_stream_token(self) -> int: + """Get the current stream id from the _device_list_id_gen""" + ... @trace - @defer.inlineCallbacks - def get_user_devices_from_cache(self, query_list): + async def get_user_devices_from_cache( + self, query_list: List[Tuple[str, str]] + ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: - query_list(list): List of (user_id, device_ids), if device_ids is + query_list: List of (user_id, device_ids), if device_ids is falsey then return all device ids for that user. Returns: - (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is - a set of user_ids and results_map is a mapping of - user_id -> device_id -> device_info + A tuple of (user_ids_not_in_cache, results_map), where + user_ids_not_in_cache is a set of user_ids and results_map is a + mapping of user_id -> device_id -> device_info. """ user_ids = {user_id for user_id, _ in query_list} - user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids)) # We go and check if any of the users need to have their device lists # resynced. If they do then we remove them from the cached list. - users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( + users_needing_resync = await self.get_user_ids_requiring_device_list_resync( user_ids ) user_ids_in_cache = { @@ -439,19 +448,19 @@ class DeviceWorkerStore(SQLBaseStore): continue if device_id: - device = yield self._get_cached_user_device(user_id, device_id) + device = await self._get_cached_user_device(user_id, device_id) results.setdefault(user_id, {})[device_id] = device else: - results[user_id] = yield self.get_cached_devices_for_user(user_id) + results[user_id] = await self.get_cached_devices_for_user(user_id) set_tag("in_cache", results) set_tag("not_in_cache", user_ids_not_in_cache) return user_ids_not_in_cache, results - @cachedInlineCallbacks(num_args=2, tree=True) - def _get_cached_user_device(self, user_id, device_id): - content = yield self.db.simple_select_one_onecol( + @cached(num_args=2, tree=True) + async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: + content = await self.db_pool.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -459,9 +468,9 @@ class DeviceWorkerStore(SQLBaseStore): ) return db_to_json(content) - @cachedInlineCallbacks() - def get_cached_devices_for_user(self, user_id): - devices = yield self.db.simple_select_list( + @cached() + async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: + devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -471,62 +480,18 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id): - """Get all devices (with any device keys) for a user - - Returns: - (stream_id, devices) - """ - return self.db.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn(self, txn, user_id): - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn( - txn, [(user_id, None)], include_all_devices=True - ) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in iteritems(user_devices): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - - def get_users_whose_devices_changed(self, from_key, user_ids): + async def get_users_whose_devices_changed( + self, from_key: str, user_ids: Iterable[str] + ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key (str): The device lists stream token - user_ids (Iterable[str]) + from_key: The device lists stream token + user_ids: The user IDs to query for devices. Returns: - Deferred[set[str]]: The set of user_ids whose devices have changed - since `from_key` + The set of user_ids whose devices have changed since `from_key` """ from_key = int(from_key) @@ -537,7 +502,7 @@ class DeviceWorkerStore(SQLBaseStore): ) if not to_check: - return defer.succeed(set()) + return set() def _get_users_whose_devices_changed_txn(txn): changes = set() @@ -557,18 +522,22 @@ class DeviceWorkerStore(SQLBaseStore): return changes - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) - @defer.inlineCallbacks - def get_users_whose_signatures_changed(self, user_id, from_key): + async def get_users_whose_signatures_changed( + self, user_id: str, from_key: str + ) -> Set[str]: """Get the users who have new cross-signing signatures made by `user_id` since `from_key`. Args: - user_id (str): the user who made the signatures - from_key (str): The device lists stream token + user_id: the user who made the signatures + from_key: The device lists stream token + + Returns: + A set of user IDs with updated signatures. """ from_key = int(from_key) if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): @@ -576,48 +545,76 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self.db.execute( + rows = await self.db_pool.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) - return {user for row in rows for user in json.loads(row[0])} + return {user for row in rows for user in db_to_json(row[0])} else: return set() async def get_all_device_list_changes_for_remotes( - self, from_key: int, to_key: int, limit: int, - ) -> List[Tuple[int, str]]: - """Return a list of `(stream_id, entity)` which is the combined list of - changes to devices and which destinations need to be poked. Entity is - either a user ID (starting with '@') or a remote destination. - """ + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for device lists replication stream. - # This query Does The Right Thing where it'll correctly apply the - # bounds to the inner queries. - sql = """ - SELECT stream_id, entity FROM ( - SELECT stream_id, user_id AS entity FROM device_lists_stream - UNION ALL - SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes - ) AS e - WHERE ? < stream_id AND stream_id <= ? - LIMIT ? + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updates. + + The updates are a list of 2-tuples of stream ID and the row data """ - return await self.db.execute( + if last_id == current_id: + return [], current_id, False + + def _get_all_device_list_changes_for_remotes(txn): + # This query Does The Right Thing where it'll correctly apply the + # bounds to the inner queries. + sql = """ + SELECT stream_id, entity FROM ( + SELECT stream_id, user_id AS entity FROM device_lists_stream + UNION ALL + SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes + ) AS e + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? + """ + + txn.execute(sql, (last_id, current_id, limit)) + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( "get_all_device_list_changes_for_remotes", - None, - sql, - from_key, - to_key, - limit, + _get_all_device_list_changes_for_remotes, ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id): + async def get_device_list_last_stream_id_for_remote( + self, user_id: str + ) -> Optional[Any]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -628,10 +625,9 @@ class DeviceWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_device_list_last_stream_id_for_remote", list_name="user_ids", - inlineCallbacks=True, ) - def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self.db.simple_select_many_batch( + async def get_device_list_last_stream_id_for_remotes(self, user_ids: str): + rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -644,8 +640,7 @@ class DeviceWorkerStore(SQLBaseStore): return results - @defer.inlineCallbacks - def get_user_ids_requiring_device_list_resync( + async def get_user_ids_requiring_device_list_resync( self, user_ids: Optional[Collection[str]] = None, ) -> Set[str]: """Given a list of remote users return the list of users that we @@ -656,7 +651,7 @@ class DeviceWorkerStore(SQLBaseStore): The IDs of users whose device lists need resync. """ if user_ids: - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_resync", column="user_id", iterable=user_ids, @@ -664,7 +659,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_user_ids_requiring_device_list_resync_with_iterable", ) else: - rows = yield self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( table="device_lists_remote_resync", keyvalues=None, retcols=("user_id",), @@ -673,11 +668,11 @@ class DeviceWorkerStore(SQLBaseStore): return {row["user_id"] for row in rows} - def mark_remote_user_device_cache_as_stale(self, user_id: str): + async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: """Records that the server has reason to believe the cache of the devices for the remote users is out of date. """ - return self.db.simple_upsert( + await self.db_pool.simple_upsert( table="device_lists_remote_resync", keyvalues={"user_id": user_id}, values={}, @@ -685,12 +680,12 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) - def mark_remote_user_device_list_as_unsubscribed(self, user_id): + async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: """Mark that we no longer track device lists for remote user. """ def _mark_remote_user_device_list_as_unsubscribed_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -699,17 +694,17 @@ class DeviceWorkerStore(SQLBaseStore): txn, self.get_device_list_last_stream_id_for_remote, (user_id,) ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "mark_remote_user_device_list_as_unsubscribed", _mark_remote_user_device_list_as_unsubscribed_txn, ) class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "device_lists_stream_idx", index_name="device_lists_stream_user_id", table="device_lists_stream", @@ -717,7 +712,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ) # create a unique index on device_lists_remote_cache - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "device_lists_remote_cache_unique_idx", index_name="device_lists_remote_cache_unique_id", table="device_lists_remote_cache", @@ -726,7 +721,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ) # And one on device_lists_remote_extremeties - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "device_lists_remote_extremeties_unique_idx", index_name="device_lists_remote_extremeties_unique_idx", table="device_lists_remote_extremeties", @@ -735,35 +730,34 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ) # once they complete, we can remove the old non-unique indexes. - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, self._drop_device_list_streams_non_unique_indexes, ) # clear out duplicate device list outbound pokes - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, ) # a pair of background updates that were added during the 1.14 release cycle, # but replaced with 58/06dlols_unique_idx.py - self.db.updates.register_noop_background_update( + self.db_pool.updates.register_noop_background_update( "device_lists_outbound_last_success_unique_idx", ) - self.db.updates.register_noop_background_update( + self.db_pool.updates.register_noop_background_update( "drop_device_lists_outbound_last_success_non_unique_idx", ) - @defer.inlineCallbacks - def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() - yield self.db.runWithConnection(f) - yield self.db.updates._end_background_update( + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES ) return 1 @@ -783,7 +777,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): def _txn(txn): clause, args = make_tuple_comparison_clause( - self.db.engine, [(x, last_row[x]) for x in KEY_COLS] + self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS] ) sql = """ SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts @@ -799,30 +793,32 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ",".join(KEY_COLS), # ORDER BY ) txn.execute(sql, args + [batch_size]) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) row = None for row in rows: - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, ) row["sent"] = False - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "device_lists_outbound_pokes", row, ) if row: - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, ) return len(rows) - rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn) + rows = await self.db_pool.runInteraction( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn + ) if not rows: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES ) @@ -830,7 +826,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(DeviceStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies @@ -841,18 +837,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - @defer.inlineCallbacks - def store_device(self, user_id, device_id, initial_device_display_name): + async def store_device( + self, user_id: str, device_id: str, initial_device_display_name: str + ) -> bool: """Ensure the given device is known; add it to the store if not Args: - user_id (str): id of user associated with the device - device_id (str): id of device - initial_device_display_name (str): initial displayname of the - device. Ignored if device exists. + user_id: id of user associated with the device + device_id: id of device + initial_device_display_name: initial displayname of the device. + Ignored if device exists. + Returns: - defer.Deferred: boolean whether the device was inserted or an - existing device existed with that ID. + Whether the device was inserted or an existing device existed with that ID. + Raises: StoreError: if the device is already in use """ @@ -861,7 +859,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self.db.simple_insert( + inserted = await self.db_pool.simple_insert( "devices", values={ "user_id": user_id, @@ -875,7 +873,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self.db.simple_select_one_onecol( + hidden = await self.db_pool.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -900,17 +898,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """Delete a device. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the device + device_id: The ID of the device to delete """ - yield self.db.simple_delete_one( + await self.db_pool.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -918,17 +913,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.device_id_exists_cache.invalidate((user_id, device_id)) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: - user_id (str): The ID of the user which owns the devices - device_ids (list): The IDs of the devices to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the devices + device_ids: The IDs of the devices to delete """ - yield self.db.simple_delete_many( + await self.db_pool.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -938,50 +930,46 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) - def update_device(self, user_id, device_id, new_display_name=None): + async def update_device( + self, user_id: str, device_id: str, new_display_name: Optional[str] = None + ) -> None: """Update a device. Only updates the device if it is not marked as hidden. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to update - new_display_name (str|None): new displayname for device; None - to leave unchanged + user_id: The ID of the user which owns the device + device_id: The ID of the device to update + new_display_name: new displayname for device; None to leave unchanged Raises: StoreError: if the device is not found - Returns: - defer.Deferred """ updates = {} if new_display_name is not None: updates["display_name"] = new_display_name if not updates: - return defer.succeed(None) - return self.db.simple_update_one( + return None + await self.db_pool.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, desc="update_device", ) - def update_remote_device_list_cache_entry( - self, user_id, device_id, content, stream_id - ): + async def update_remote_device_list_cache_entry( + self, user_id: str, device_id: str, content: JsonDict, stream_id: int + ) -> None: """Updates a single device in the cache of a remote user's devicelist. Note: assumes that we are the only thread that can be updating this user's device list. Args: - user_id (str): User to update device list for - device_id (str): ID of decivice being updated - content (dict): new data on this device - stream_id (int): the version of the device list - - Returns: - Deferred[None] + user_id: User to update device list for + device_id: ID of decivice being updated + content: new data on this device + stream_id: the version of the device list """ - return self.db.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -991,10 +979,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_entry_txn( - self, txn, user_id, device_id, content, stream_id - ): + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + content: JsonDict, + stream_id: int, + ) -> None: if content.get("deleted"): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -1002,11 +995,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, - values={"content": json.dumps(content)}, + values={"content": json_encoder.encode(content)}, # we don't need to lock, because we assume we are the only thread # updating this user's devices. lock=False, @@ -1018,7 +1011,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -1028,21 +1021,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): lock=False, ) - def update_remote_device_list_cache(self, user_id, devices, stream_id): + async def update_remote_device_list_cache( + self, user_id: str, devices: List[dict], stream_id: int + ) -> None: """Replace the entire cache of the remote user's devices. Note: assumes that we are the only thread that can be updating this user's device list. Args: - user_id (str): User to update device list for - devices (list[dict]): list of device objects supplied over federation - stream_id (int): the version of the device list - - Returns: - Deferred[None] + user_id: User to update device list for + devices: list of device objects supplied over federation + stream_id: the version of the device list """ - return self.db.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -1050,19 +1042,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): stream_id, ) - def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self.db.simple_delete_txn( + def _update_remote_device_list_cache_txn( + self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int + ) -> None: + self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ { "user_id": user_id, "device_id": content["device_id"], - "content": json.dumps(content), + "content": json_encoder.encode(content), } for content in devices ], @@ -1074,7 +1068,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -1087,20 +1081,23 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # If we're replacing the remote user's device list cache presumably # we've done a full resync, so we remove the entry that says we need # to resync - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, ) - @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_ids, hosts): + async def add_device_change_to_streams( + self, user_id: str, device_ids: Collection[str], hosts: List[str] + ): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ if not device_ids: return - with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: - yield self.db.runInteraction( + with await self._device_list_id_gen.get_next_mult( + len(device_ids) + ) as stream_ids: + await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, user_id, @@ -1112,10 +1109,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with self._device_list_id_gen.get_next_mult( + with await self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_device_outbound_poke_to_stream", self._add_device_outbound_poke_to_stream_txn, user_id, @@ -1150,7 +1147,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [(user_id, device_id, min_stream_id) for device_id in device_ids], ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="device_lists_stream", values=[ @@ -1160,7 +1157,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _add_device_outbound_poke_to_stream_txn( - self, txn, user_id, device_ids, hosts, stream_ids, context, + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + hosts: List[str], + stream_ids: List[str], + context: Dict[str, str], ): for host in hosts: txn.call_after( @@ -1172,7 +1175,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): now = self._clock.time_msec() next_stream_id = iter(stream_ids) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ @@ -1183,7 +1186,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "device_id": device_id, "sent": False, "ts": now, - "opentracing_context": json.dumps(context) + "opentracing_context": json_encoder.encode(context) if whitelisted_homeserver(destination) else "{}", } @@ -1192,7 +1195,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) - def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): + def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure that we don't fill up due to dead servers. @@ -1279,7 +1282,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return run_as_background_process( "prune_old_outbound_device_pokes", - self.db.runInteraction, + self.db_pool.runInteraction, "_prune_old_outbound_device_pokes", _prune_txn, ) diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/databases/main/directory.py index e1d1bc3e05..e5060d4c46 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -14,30 +14,29 @@ # limitations under the License. from collections import namedtuple -from typing import Optional - -from twisted.internet import defer +from typing import Iterable, List, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore +from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) class DirectoryWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): - """ Get's the room_id and server list for a given room_alias + async def get_association_from_room_alias( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: + """Gets the room_id and server list for a given room_alias Args: - room_alias (RoomAlias) + room_alias: The alias to translate to an ID. Returns: - Deferred: results in namedtuple with keys "room_id" and - "servers" or None if no association can be found + The room alias mapping or None if no association can be found. """ - room_id = yield self.db.simple_select_one_onecol( + room_id = await self.db_pool.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self.db.simple_select_onecol( + servers = await self.db_pool.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -60,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) - def get_room_alias_creator(self, room_alias): - return self.db.simple_select_one_onecol( + async def get_room_alias_creator(self, room_alias: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -69,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore): ) @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self.db.simple_select_onecol( + async def get_aliases_for_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -79,22 +78,24 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): - @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers, creator=None): + async def create_room_alias_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Iterable[str], + creator: Optional[str] = None, + ) -> None: """ Creates an association between a room alias and room_id/servers Args: - room_alias (RoomAlias) - room_id (str) - servers (list) - creator (str): Optional user_id of creator. - - Returns: - Deferred + room_alias: The alias to create. + room_id: The target of the alias. + servers: A list of servers through which it may be possible to join the room + creator: Optional user_id of creator. """ def alias_txn(txn): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "room_aliases", { @@ -104,7 +105,7 @@ class DirectoryStore(DirectoryWorkerStore): }, ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="room_alias_servers", values=[ @@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore): ) try: - ret = yield self.db.runInteraction( + await self.db_pool.runInteraction( "create_room_alias_association", alias_txn ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() ) - return ret - @defer.inlineCallbacks - def delete_room_alias(self, room_alias): - room_id = yield self.db.runInteraction( + async def delete_room_alias(self, room_alias: RoomAlias) -> str: + room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) return room_id - def _delete_room_alias_txn(self, txn, room_alias): + def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),), @@ -160,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room( + async def update_aliases_for_room( self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, - ): + ) -> None: """Repoint all of the aliases for a given room, to a different room. Args: @@ -190,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 23f4570c4b..12cecceec2 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,18 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json - -from twisted.internet import defer +from typing import Optional from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util import json_encoder class EndToEndRoomKeyStore(SQLBaseStore): - @defer.inlineCallbacks - def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + async def update_e2e_room_key( + self, user_id, version, room_id, session_id, room_key + ): """Replaces the encrypted E2E room key for a given session in a given backup Args: @@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self.db.simple_update_one( + await self.db_pool.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -50,13 +50,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), }, desc="update_e2e_room_key", ) - @defer.inlineCallbacks - def add_e2e_room_keys(self, user_id, version, room_keys): + async def add_e2e_room_keys(self, user_id, version, room_keys): """Bulk add room keys to a given backup. Args: @@ -77,7 +76,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), } ) log_kv( @@ -89,13 +88,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): } ) - yield self.db.simple_insert_many( + await self.db_pool.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace - @defer.inlineCallbacks - def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. @@ -110,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): the backup (or for the specified room) Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -125,7 +123,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -148,12 +146,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): "forwarded_count": row["forwarded_count"], # is_verified must be returned to the client as a boolean "is_verified": bool(row["is_verified"]), - "session_data": json.loads(row["session_data"]), + "session_data": db_to_json(row["session_data"]), } return sessions - def get_e2e_room_keys_multi(self, user_id, version, room_keys): + async def get_e2e_room_keys_multi(self, user_id, version, room_keys): """Get multiple room keys at a time. The difference between this function and get_e2e_room_keys is that this function can be used to retrieve multiple specific keys at a time, whereas get_e2e_room_keys is used for @@ -168,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): that we want to query Returns: - Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -222,20 +220,20 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": row[2], "forwarded_count": row[3], "is_verified": row[4], - "session_data": json.loads(row[5]), + "session_data": db_to_json(row[5]), } return ret - def count_e2e_room_keys(self, user_id, version): + async def count_e2e_room_keys(self, user_id: str, version: str) -> int: """Get the number of keys in a backup version. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about + user_id: the user whose backup we're querying + version: the version ID of the backup we're querying about """ - return self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", @@ -243,8 +241,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - @defer.inlineCallbacks - def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. @@ -259,7 +258,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): the backup (or for the specified room) Returns: - A deferred of the deletion transaction + The deletion transaction """ keyvalues = {"user_id": user_id, "version": int(version)} @@ -268,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self.db.simple_delete( + await self.db_pool.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -284,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): raise StoreError(404, "No current backup version") return row[0] - def get_e2e_room_keys_version_info(self, user_id, version=None): + async def get_e2e_room_keys_version_info(self, user_id, version=None): """Get info metadata about a version of our room_keys backup. Args: @@ -294,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present Returns: - A deferred dict giving the info metadata for this backup version, with + A dict giving the info metadata for this backup version, with fields including: version(str) algorithm(str) @@ -313,24 +312,24 @@ class EndToEndRoomKeyStore(SQLBaseStore): # it isn't there. raise StoreError(404, "No row found") - result = self.db.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, retcols=("version", "algorithm", "auth_data", "etag"), ) - result["auth_data"] = json.loads(result["auth_data"]) + result["auth_data"] = db_to_json(result["auth_data"]) result["version"] = str(result["version"]) if result["etag"] is None: result["etag"] = 0 return result - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @trace - def create_e2e_room_keys_version(self, user_id, info): + async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str: """Atomically creates a new version of this user's e2e_room_keys store with the given version info. @@ -339,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): info(dict): the info about the backup version to be created Returns: - A deferred string for the newly created version ID + The newly created version ID """ def _create_e2e_room_keys_version_txn(txn): @@ -353,46 +352,50 @@ class EndToEndRoomKeyStore(SQLBaseStore): new_version = str(int(current_version) + 1) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ "user_id": user_id, "version": new_version, "algorithm": info["algorithm"], - "auth_data": json.dumps(info["auth_data"]), + "auth_data": json_encoder.encode(info["auth_data"]), }, ) return new_version - return self.db.runInteraction( + return await self.db_pool.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @trace - def update_e2e_room_keys_version( - self, user_id, version, info=None, version_etag=None - ): + async def update_e2e_room_keys_version( + self, + user_id: str, + version: str, + info: Optional[dict] = None, + version_etag: Optional[int] = None, + ) -> None: """Update a given backup version Args: - user_id(str): the user whose backup version we're updating - version(str): the version ID of the backup version we're updating - info (dict): the new backup version info to store. If None, then - the backup version info is not updated - version_etag (Optional[int]): etag of the keys in the backup. If - None, then the etag is not updated + user_id: the user whose backup version we're updating + version: the version ID of the backup version we're updating + info: the new backup version info to store. If None, then the backup + version info is not updated. + version_etag: etag of the keys in the backup. If None, then the etag + is not updated. """ updatevalues = {} if info is not None and "auth_data" in info: - updatevalues["auth_data"] = json.dumps(info["auth_data"]) + updatevalues["auth_data"] = json_encoder.encode(info["auth_data"]) if version_etag is not None: updatevalues["etag"] = version_etag if updatevalues: - return self.db.simple_update( + await self.db_pool.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, @@ -400,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def delete_e2e_room_keys_version(self, user_id, version=None): + async def delete_e2e_room_keys_version( + self, user_id: str, version: Optional[str] = None + ) -> None: """Delete a given backup version of the user's room keys. Doesn't delete their actual key data. Args: - user_id(str): the user whose backup version we're deleting - version(str): Optional. the version ID of the backup version we're deleting + user_id: the user whose backup version we're deleting + version: Optional. the version ID of the backup version we're deleting If missing, we delete the current backup version info. Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present, @@ -421,19 +426,19 @@ class EndToEndRoomKeyStore(SQLBaseStore): else: this_version = version - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="e2e_room_keys", keyvalues={"user_id": user_id, "version": this_version}, ) - return self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 20698bfd16..fba3098ea2 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,36 +14,79 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +import abc +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple -from six import iteritems - -from canonicaljson import encode_canonical_json, json +import attr +from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection -from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.types import Cursor +from synapse.types import JsonDict +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.handlers.e2e_keys import SignatureListItem + + +@attr.s +class DeviceKeyLookupResult: + """The type returned by get_e2e_device_keys_and_signatures""" + + display_name = attr.ib(type=Optional[str]) + + # the key data from e2e_device_keys_json. Typically includes fields like + # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing + # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.) + keys = attr.ib(type=Optional[JsonDict]) + class EndToEndKeyWorkerStore(SQLBaseStore): + async def get_e2e_device_keys_for_federation_query( + self, user_id: str + ) -> Tuple[int, List[JsonDict]]: + """Get all devices (with any device keys) for a user + + Returns: + (stream_id, devices) + """ + now_stream_id = self.get_device_stream_token() + + devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.items(): + result = {"device_id": device_id} + + keys = device.keys + if keys: + result["keys"] = keys + + device_display_name = device.display_name + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + @trace - @defer.inlineCallbacks - def get_e2e_device_keys( - self, query_list, include_all_devices=False, include_deleted_devices=False - ): - """Fetch a list of device keys. + async def get_e2e_device_keys_for_cs_api( + self, query_list: List[Tuple[str, Optional[str]]] + ) -> Dict[str, Dict[str, JsonDict]]: + """Fetch a list of device keys, formatted suitably for the C/S API. Args: query_list(list): List of pairs of user_ids and device_ids. - include_all_devices (bool): whether to include entries for devices - that don't have device keys - include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but were in the query_list). - This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to key data. The key data will be a dict in the same format as the @@ -53,45 +96,103 @@ class EndToEndKeyWorkerStore(SQLBaseStore): if not query_list: return {} - results = yield self.db.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, - query_list, - include_all_devices, - include_deleted_devices, - ) + results = await self.get_e2e_device_keys_and_signatures(query_list) # Build the result structure, un-jsonify the results, and add the # "unsigned" section rv = {} - for user_id, device_keys in iteritems(results): + for user_id, device_keys in results.items(): rv[user_id] = {} - for device_id, device_info in iteritems(device_keys): - r = db_to_json(device_info.pop("key_json")) + for device_id, device_info in device_keys.items(): + r = device_info.keys r["unsigned"] = {} - display_name = device_info["device_display_name"] + display_name = device_info.display_name if display_name is not None: r["unsigned"]["device_display_name"] = display_name - if "signatures" in device_info: - for sig_user_id, sigs in device_info["signatures"].items(): - r.setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) rv[user_id][device_id] = r return rv @trace - def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False - ): + async def get_e2e_device_keys_and_signatures( + self, + query_list: List[Tuple[str, Optional[str]]], + include_all_devices: bool = False, + include_deleted_devices: bool = False, + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + """Fetch a list of device keys + + Any cross-signatures made on the keys by the owner of the device are also + included. + + The cross-signatures are added to the `signatures` field within the `keys` + object in the response. + + Args: + query_list: List of pairs of user_ids and device_ids. Device id can be None + to indicate "all devices for this user" + + include_all_devices: whether to return devices without device keys + + include_deleted_devices: whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. + + Returns: + Dict mapping from user-id to dict mapping from device_id to + key data. + """ set_tag("include_all_devices", include_all_devices) set_tag("include_deleted_devices", include_deleted_devices) + result = await self.db_pool.runInteraction( + "get_e2e_device_keys", + self._get_e2e_device_keys_txn, + query_list, + include_all_devices, + include_deleted_devices, + ) + + # get the (user_id, device_id) tuples to look up cross-signatures for + signature_query = ( + (user_id, device_id) + for user_id, dev in result.items() + for device_id, d in dev.items() + if d is not None and d.keys is not None + ) + + for batch in batch_iter(signature_query, 50): + cross_sigs_result = await self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_for_devices_txn, + batch, + ) + + # add each cross-signing signature to the correct device in the result dict. + for (user_id, key_id, device_id, signature) in cross_sigs_result: + target_device_result = result[user_id][device_id] + target_device_signatures = target_device_result.keys.setdefault( + "signatures", {} + ) + signing_user_signatures = target_device_signatures.setdefault( + user_id, {} + ) + signing_user_signatures[key_id] = signature + + log_kv(result) + return result + + def _get_e2e_device_keys_txn( + self, txn, query_list, include_all_devices=False, include_deleted_devices=False + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + """Get information on devices from the database + + The results include the device's keys and self-signatures, but *not* any + cross-signing signatures which have been added subsequently (for which, see + get_e2e_device_keys_and_signatures) + """ query_clauses = [] query_params = [] - signature_query_clauses = [] - signature_query_params = [] if include_all_devices is False: include_deleted_devices = False @@ -102,24 +203,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) - signature_query_clause = "target_user_id = ?" - signature_query_params.append(user_id) if device_id is not None: query_clause += " AND device_id = ?" query_params.append(device_id) - signature_query_clause += " AND target_device_id = ?" - signature_query_params.append(device_id) - - signature_query_clause += " AND user_id = ?" - signature_query_params.append(user_id) query_clauses.append(query_clause) - signature_query_clauses.append(signature_query_clause) sql = ( "SELECT user_id, device_id, " - " d.display_name AS device_display_name, " + " d.display_name, " " k.key_json" " FROM devices d" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" @@ -130,54 +223,53 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(sql, query_params) - rows = self.db.cursor_to_dict(txn) - result = {} - for row in rows: + result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] + for (user_id, device_id, display_name, key_json) in txn: if include_deleted_devices: - deleted_devices.remove((row["user_id"], row["device_id"])) - result.setdefault(row["user_id"], {})[row["device_id"]] = row + deleted_devices.remove((user_id, device_id)) + result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult( + display_name, db_to_json(key_json) if key_json else None + ) if include_deleted_devices: for user_id, device_id in deleted_devices: result.setdefault(user_id, {})[device_id] = None - # get signatures on the device - signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) + return result - txn.execute(signature_sql, signature_query_params) - rows = self.db.cursor_to_dict(txn) - - # add each cross-signing signature to the correct device in the result dict. - for row in rows: - signing_user_id = row["user_id"] - signing_key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] - signature = row["signature"] - - target_user_result = result.get(target_user_id) - if not target_user_result: - continue + def _get_e2e_cross_signing_signatures_for_devices_txn( + self, txn: Cursor, device_query: Iterable[Tuple[str, str]] + ) -> List[Tuple[str, str, str, str]]: + """Get cross-signing signatures for a given list of devices - target_device_result = target_user_result.get(target_device_id) - if not target_device_result: - # note that target_device_result will be None for deleted devices. - continue + Returns signatures made by the owners of the devices. - target_device_signatures = target_device_result.setdefault("signatures", {}) - signing_user_signatures = target_device_signatures.setdefault( - signing_user_id, {} + Returns: a list of results; each entry in the list is a tuple of + (user_id, key_id, target_device_id, signature). + """ + signature_query_clauses = [] + signature_query_params = [] + + for (user_id, device_id) in device_query: + signature_query_clauses.append( + "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) - signing_user_signatures[signing_key_id] = signature + signature_query_params.extend([user_id, device_id, user_id]) - log_kv(result) - return result + signature_sql = """ + SELECT user_id, key_id, target_device_id, signature + FROM e2e_cross_signing_signatures WHERE %s + """ % ( + " OR ".join("(" + q + ")" for q in signature_query_clauses) + ) - @defer.inlineCallbacks - def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + txn.execute(signature_sql, signature_query_params) + return txn.fetchall() + + async def get_e2e_one_time_keys( + self, user_id: str, device_id: str, key_ids: List[str] + ) -> Dict[Tuple[str, str], str]: """Retrieve a number of one-time keys for a user Args: @@ -187,11 +279,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): retrieve Returns: - deferred resolving to Dict[(str, str), str]: map from (algorithm, - key_id) to json string for key + A map from (algorithm, key_id) to json string for key """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -203,17 +294,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore): log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) return result - @defer.inlineCallbacks - def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + async def add_e2e_one_time_keys( + self, + user_id: str, + device_id: str, + time_now: int, + new_keys: Iterable[Tuple[str, str, str]], + ) -> None: """Insert some new one time keys for a device. Errors if any of the keys already exist. Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - time_now(long): insertion time to record (ms since epoch) - new_keys(iterable[(str, str, str)]: keys to add - each a tuple of - (algorithm, key_id, key json) + user_id: id of user to get keys for + device_id: id of device to get keys for + time_now: insertion time to record (ms since epoch) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ def _add_e2e_one_time_keys(txn): @@ -224,7 +319,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", values=[ @@ -243,15 +338,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @cached(max_entries=10000) - def count_e2e_one_time_keys(self, user_id, device_id): + async def count_e2e_one_time_keys( + self, user_id: str, device_id: str + ) -> Dict[str, int]: """ Count the number of one time keys the server has for a device Returns: - Dict mapping from algorithm to number of keys for that algorithm. + A mapping from algorithm to number of keys for that algorithm. """ def _count_e2e_one_time_keys(txn): @@ -266,26 +363,27 @@ class EndToEndKeyWorkerStore(SQLBaseStore): result[algorithm] = key_count return result - return self.db.runInteraction( + return await self.db_pool.runInteraction( "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - @defer.inlineCallbacks - def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + async def get_e2e_cross_signing_key( + self, user_id: str, key_type: str, from_user_id: Optional[str] = None + ) -> Optional[dict]: """Returns a user's cross-signing key. Args: - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being requested: either 'master' + user_id: the user whose key is being requested + key_type: the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on + from_user_id: if specified, signatures made by this user on the self-signing key will be included in the result Returns: dict of the key data or None if not found """ - res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) + res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) user_keys = res.get(user_id) if not user_keys: return None @@ -303,7 +401,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): list_name="user_ids", num_args=1, ) - def _get_bare_e2e_cross_signing_keys_bulk( + async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: List[str] ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. The output of this @@ -311,16 +409,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore): the signatures for the calling user need to be fetched. Args: - user_ids (list[str]): the users whose keys are being requested + user_ids: the users whose keys are being requested Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. + A mapping from user ID to key type to key data. If a user's cross-signing + keys were not found, either their user ID will not be in the dict, or + their user ID will map to None. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, @@ -363,12 +460,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(sql, params) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) for row in rows: user_id = row["user_id"] key_type = row["keytype"] - key = json.loads(row["keydata"]) + key = db_to_json(row["keydata"]) user_info = result.setdefault(user_id, {}) user_info[key_type] = key @@ -422,7 +519,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): query_params.extend(item) txn.execute(sql, query_params) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) # and add the signatures to the appropriate keys for row in rows: @@ -451,28 +548,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return keys - @defer.inlineCallbacks - def get_e2e_cross_signing_keys_bulk( - self, user_ids: List[str], from_user_id: str = None - ) -> defer.Deferred: + async def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: Optional[str] = None + ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. Args: - user_ids (list[str]): the users whose keys are being requested - from_user_id (str): if specified, signatures made by this user on + user_ids: the users whose keys are being requested + from_user_id: if specified, signatures made by this user on the self-signing keys will be included in the result Returns: - Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to - key data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. + A map of user ID to key type to key data. If a user's cross-signing + keys were not found, either their user ID will not be in the dict, + or their user ID will map to None. """ - result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( "get_e2e_cross_signing_signatures", self._get_e2e_cross_signing_signatures_txn, result, @@ -481,39 +576,73 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return result - def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit): - """Return a list of changes from the user signature stream to notify remotes. + async def get_all_user_signature_changes_for_remotes( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for groups replication stream. + Note that the user signature stream represents when a user signs their device with their user-signing key, which is not published to other users or servers, so no `destination` is needed in the returned list. However, this is needed to poke workers. Args: - from_key (int): the stream ID to start at (exclusive) - to_key (int): the stream ID to end at (inclusive) + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. Returns: - Deferred[list[(int,str)]] a list of `(stream_id, user_id)` - """ - sql = """ - SELECT stream_id, from_user_id AS user_id - FROM user_signature_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ - return self.db.execute( + + if last_id == current_id: + return [], current_id, False + + def _get_all_user_signature_changes_for_remotes_txn(txn): + sql = """ + SELECT stream_id, from_user_id AS user_id + FROM user_signature_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + + updates = [(row[0], (row[1:])) for row in txn] + + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( "get_all_user_signature_changes_for_remotes", - None, - sql, - from_key, - to_key, - limit, + _get_all_user_signature_changes_for_remotes_txn, ) + @abc.abstractmethod + def get_device_stream_token(self) -> int: + """Get the current stream id from the _device_list_id_gen""" + ... + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + async def set_e2e_device_keys( + self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict + ) -> bool: """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. """ @@ -524,7 +653,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): set_tag("time_now", time_now) set_tag("device_keys", device_keys) - old_key_json = self.db.simple_select_one_onecol_txn( + old_key_json = self.db_pool.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -540,7 +669,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"Message": "Device key already stored."}) return False - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -549,10 +678,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"message": "Device keys stored."}) return True - return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) + return await self.db_pool.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn + ) - def claim_e2e_one_time_keys(self, query_list): - """Take a list of one time keys out of the database""" + async def claim_e2e_one_time_keys( + self, query_list: Iterable[Tuple[str, str, str]] + ) -> Dict[str, Dict[str, Dict[str, bytes]]]: + """Take a list of one time keys out of the database. + + Args: + query_list: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A map of user ID -> a map device ID -> a map of key ID -> JSON bytes. + """ @trace def _claim_e2e_one_time_keys(txn): @@ -588,11 +728,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) return result - return self.db.runInteraction( + return await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - def delete_e2e_keys_by_device(self, user_id, device_id): + async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn): log_kv( { @@ -601,12 +741,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "user_id": user_id, } ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -615,11 +755,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): """Set a user's cross-signing key. Args: @@ -629,6 +769,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key (dict): the key data + stream_id (int) """ # the 'key' dict will look something like: # { @@ -654,7 +795,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # We only need to do this for local users, since remote servers should be # responsible for checking this for their own users. if self.hs.is_mine_id(user_id): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "devices", values={ @@ -666,23 +807,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) # and finally, store the key itself - with self._cross_signing_id_gen.get_next() as stream_id: - self.db.simple_insert_txn( - txn, - "e2e_cross_signing_keys", - values={ - "user_id": user_id, - "keytype": key_type, - "keydata": json.dumps(key), - "stream_id": stream_id, - }, - ) + self.db_pool.simple_insert_txn( + txn, + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json_encoder.encode(key), + "stream_id": stream_id, + }, + ) self._invalidate_cache_and_stream( txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. Args: @@ -690,22 +830,27 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.db.runInteraction( - "add_e2e_cross_signing_key", - self._set_e2e_cross_signing_key_txn, - user_id, - key_type, - key, - ) - def store_e2e_cross_signing_signatures(self, user_id, signatures): + with await self._cross_signing_id_gen.get_next() as stream_id: + return await self.db_pool.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + stream_id, + ) + + async def store_e2e_cross_signing_signatures( + self, user_id: str, signatures: "Iterable[SignatureListItem]" + ) -> None: """Stores cross-signing signatures. Args: - user_id (str): the user who made the signatures - signatures (iterable[SignatureListItem]): signatures to add + user_id: the user who made the signatures + signatures: signatures to add """ - return self.db.simple_insert_many( + await self.db_pool.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 24ce8c4330..0b69aa6a94 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,18 +14,17 @@ # limitations under the License. import itertools import logging -from typing import Dict, List, Optional, Set, Tuple - -from six.moves.queue import Empty, PriorityQueue - -from twisted.internet import defer +from queue import Empty, PriorityQueue +from typing import Dict, Iterable, List, Set, Tuple from synapse.api.errors import StoreError +from synapse.events import EventBase from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.signatures import SignatureWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.types import Collection from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter @@ -33,57 +32,51 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain( + self, event_ids: Collection[str], include_given: bool = False + ) -> List[EventBase]: """Get auth events for given event_ids. The events *must* be state events. Args: - event_ids (list): state events - include_given (bool): include the given events in result + event_ids: state events + include_given: include the given events in result Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) - - def get_auth_chain_ids( - self, - event_ids: List[str], - include_given: bool = False, - ignore_events: Optional[Set[str]] = None, - ): + ) + return await self.get_events_as_list(event_ids) + + async def get_auth_chain_ids( + self, event_ids: Collection[str], include_given: bool = False, + ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. Args: event_ids: state events include_given: include the given events in result - ignore_events: Set of events to exclude from the returned auth - chain. This is useful if the caller will just discard the - given events anyway, and saves us from figuring out their auth - chains if not required. Returns: - list of event_ids + An awaitable which resolve to a list of event_ids """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given, - ignore_events, ) - def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): - if ignore_events is None: - ignore_events = set() - + def _get_auth_chain_ids_txn( + self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool + ) -> List[str]: if include_given: results = set(event_ids) else: results = set() - base_sql = "SELECT auth_id FROM event_auth WHERE " + base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE " front = set(event_ids) while front: @@ -95,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(base_sql + clause, args) new_front.update(r[0] for r in txn) - new_front -= ignore_events new_front -= results front = new_front @@ -103,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -112,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas chain. Returns: - Deferred[Set[str]] + The set of the difference in auth chains. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, @@ -260,13 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_in_room(self, room_id): - return self.db.runInteraction( - "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id - ) - - def get_oldest_events_with_depth_in_room(self, room_id): - return self.db.runInteraction( + async def get_oldest_events_with_depth_in_room(self, room_id): + return await self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -287,17 +274,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return dict(txn) - @defer.inlineCallbacks - def get_max_depth_of(self, event_ids): + async def get_max_depth_of(self, event_ids: List[str]) -> int: """Returns the max depth of a set of event IDs Args: - event_ids (list[str]) - - Returns - Deferred[int] + event_ids: The event IDs to calculate the max depth of. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -310,15 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: return max(row["depth"] for row in rows) - def _get_oldest_events_in_room_txn(self, txn, room_id): - return self.db.simple_select_onecol_txn( - txn, - table="event_backward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - - def get_prev_events_for_room(self, room_id: str): + async def get_prev_events_for_room(self, room_id: str) -> List[str]: """ Gets a subset of the current forward extremities in the given room. @@ -326,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas events which refer to hundreds of prev_events. Args: - room_id (str): room_id + room_id: room_id Returns: - Deferred[List[str]]: the event ids of the forward extremites + The event ids of the forward extremities. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) @@ -353,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return [row[0] for row in txn] - def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): + async def get_rooms_with_many_extremities( + self, min_count: int, limit: int, room_id_filter: Iterable[str] + ) -> List[str]: """Get the top rooms with at least N extremities. Args: - min_count (int): The minimum number of extremities - limit (int): The maximum number of rooms to return. - room_id_filter (iterable[str]): room_ids to exclude from the results + min_count: The minimum number of extremities + limit: The maximum number of rooms to return. + room_id_filter: room_ids to exclude from the results Returns: - Deferred[list]: At most `limit` room IDs that have at least - `min_count` extremities, sorted by extremity count. + At most `limit` room IDs that have at least `min_count` extremities, + sorted by extremity count. """ def _get_rooms_with_many_extremities_txn(txn): @@ -388,28 +365,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) - def get_latest_event_ids_in_room(self, room_id): - return self.db.simple_select_onecol( + async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", desc="get_latest_event_ids_in_room", ) - def get_min_depth(self, room_id): - """ For hte given room, get the minimum depth we have seen for it. + async def get_min_depth(self, room_id: str) -> int: + """For the given room, get the minimum depth we have seen for it. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self.db.simple_select_one_onecol_txn( + min_depth = self.db_pool.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -419,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return int(min_depth) if min_depth is not None else None - def get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def get_forward_extremeties_for_room( + self, room_id: str, stream_ordering: int + ) -> List[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -427,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas stream_orderings from that point. Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: - deferred, which resolves to a list of event_ids + A list of event_ids """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. @@ -447,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas if last_change > self.stream_ordering_month_ago: stream_ordering = min(last_change, stream_ordering) - return self._get_forward_extremeties_for_room(room_id, stream_ordering) + return await self._get_forward_extremeties_for_room(room_id, stream_ordering) @cached(max_entries=5000, num_args=2) - def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -475,31 +454,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id: str, event_list: list, limit: int): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` Args: - txn - room_id (str) - event_list (list) - limit (int) + room_id + event_list + limit """ - return ( - self.db.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -521,7 +497,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas queue = PriorityQueue() for event_id in event_list: - depth = self.db.simple_select_one_onecol_txn( + depth = self.db_pool.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -551,9 +527,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return event_results - @defer.inlineCallbacks - def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.db.runInteraction( + async def get_missing_events(self, room_id, earliest_events, latest_events, limit): + ids = await self.db_pool.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -561,8 +536,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = yield self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): @@ -596,17 +570,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_results.reverse() return event_results - @defer.inlineCallbacks - def get_successor_events(self, event_ids): + async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]: """Fetch all events that have the given events as a prev event Args: - event_ids (iterable[str]) - - Returns: - Deferred[list[str]] + event_ids: The events to use as the previous events. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -629,10 +599,10 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventFederationStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) @@ -659,13 +629,13 @@ class EventFederationStore(EventFederationWorkerStore): return run_as_background_process( "delete_old_forward_extrem_cache", - self.db.runInteraction, + self.db_pool.runInteraction, "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, ) - def clean_room_for_join(self, room_id): - return self.db.runInteraction( + async def clean_room_for_join(self, room_id): + return await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) @@ -675,8 +645,7 @@ class EventFederationStore(EventFederationWorkerStore): txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - @defer.inlineCallbacks - def _background_delete_non_state_event_auth(self, progress, batch_size): + async def _background_delete_non_state_event_auth(self, progress, batch_size): def delete_event_auth(txn): target_min_stream_id = progress.get("target_min_stream_id_inclusive") max_stream_id = progress.get("max_stream_id_exclusive") @@ -709,17 +678,19 @@ class EventFederationStore(EventFederationWorkerStore): "max_stream_id_exclusive": min_stream_id, } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_AUTH_STATE_ONLY, new_progress ) return min_stream_id >= target_min_stream_id - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) if not result: - yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) + await self.db_pool.updates._end_background_update( + self.EVENT_AUTH_STATE_ONLY + ) return batch_size diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 0321274de2..5233ed83e2 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -15,17 +15,15 @@ # limitations under the License. import logging +from typing import Dict, List, Optional, Tuple, Union -from six import iteritems - -from canonicaljson import json - -from twisted.internet import defer +import attr from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore -from synapse.storage.database import Database -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.util import json_encoder +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -53,14 +51,14 @@ def _serialize_action(actions, is_highlight): else: if actions == DEFAULT_NOTIF_ACTION: return "" - return json.dumps(actions) + return json_encoder.encode(actions) def _deserialize_action(actions, is_highlight): """Custom deserializer for actions. This allows us to "compress" common actions """ if actions: - return json.loads(actions) + return db_to_json(actions) if is_highlight: return DEFAULT_HIGHLIGHT_ACTION @@ -69,7 +67,7 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn @@ -90,114 +88,141 @@ class EventPushActionsWorkerStore(SQLBaseStore): self._rotate_delay = 3 self._rotate_count = 10000 - @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) - def get_unread_event_push_actions_by_room_for_user( - self, room_id, user_id, last_read_event_id - ): - ret = yield self.db.runInteraction( + @cached(num_args=3, tree=True, max_entries=5000) + async def get_unread_event_push_actions_by_room_for_user( + self, room_id: str, user_id: str, last_read_event_id: Optional[str], + ) -> Dict[str, int]: + """Get the notification count, the highlight count and the unread message count + for a given user in a given room after the given read receipt. + + Note that this function assumes the user to be a current member of the room, + since it's either called by the sync handler to handle joined room entries, or by + the HTTP pusher to calculate the badge of unread joined rooms. + + Args: + room_id: The room to retrieve the counts in. + user_id: The user to retrieve the counts for. + last_read_event_id: The event associated with the latest read receipt for + this user in this room. None if no receipt for this user in this room. + + Returns + A dict containing the counts mentioned earlier in this docstring, + respectively under the keys "notify_count", "highlight_count" and + "unread_count". + """ + return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, user_id, last_read_event_id, ) - return ret def _get_unread_counts_by_receipt_txn( - self, txn, room_id, user_id, last_read_event_id + self, txn, room_id, user_id, last_read_event_id, ): - sql = ( - "SELECT stream_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" - ) - txn.execute(sql, (room_id, last_read_event_id)) - results = txn.fetchall() - if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} + stream_ordering = None + + if last_read_event_id is not None: + stream_ordering = self.get_stream_id_for_event_txn( + txn, last_read_event_id, allow_none=True, + ) + + if stream_ordering is None: + # Either last_read_event_id is None, or it's an event we don't have (e.g. + # because it's been purged), in which case retrieve the stream ordering for + # the latest membership event from this user in this room (which we assume is + # a join). + event_id = self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="local_current_membership", + keyvalues={"room_id": room_id, "user_id": user_id}, + retcol="event_id", + ) - stream_ordering = results[0][0] + stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) return self._get_unread_counts_by_pos_txn( txn, room_id, user_id, stream_ordering ) def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): - - # First get number of notifications. - # We don't need to put a notif=1 clause as all rows always have - # notif=1 sql = ( - "SELECT count(*)" + "SELECT" + " COUNT(CASE WHEN notif = 1 THEN 1 END)," + " COUNT(CASE WHEN highlight = 1 THEN 1 END)," + " COUNT(CASE WHEN unread = 1 THEN 1 END)" " FROM event_push_actions ea" - " WHERE" - " user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" + " WHERE user_id = ?" + " AND room_id = ?" + " AND stream_ordering > ?" ) txn.execute(sql, (user_id, room_id, stream_ordering)) row = txn.fetchone() - notify_count = row[0] if row else 0 + + (notif_count, highlight_count, unread_count) = (0, 0, 0) + + if row: + (notif_count, highlight_count, unread_count) = row txn.execute( """ - SELECT notif_count FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering > ? - """, + SELECT notif_count, unread_count FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + """, (room_id, user_id, stream_ordering), ) - rows = txn.fetchall() - if rows: - notify_count += rows[0][0] + row = txn.fetchone() - # Now get the number of highlights - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " highlight = 1" - " AND user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" - ) + if row: + notif_count += row[0] - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - highlight_count = row[0] if row else 0 + if row[1] is not None: + # The unread_count column of event_push_summary is NULLable, so we need + # to make sure we don't try increasing the unread counts if it's NULL + # for this row. + unread_count += row[1] - return {"notify_count": notify_count, "highlight_count": highlight_count} + return { + "notify_count": notif_count, + "unread_count": unread_count, + "highlight_count": highlight_count, + } - @defer.inlineCallbacks - def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): + async def get_push_action_users_in_range( + self, min_stream_ordering, max_stream_ordering + ): def f(txn): sql = ( "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" - " stream_ordering >= ? AND stream_ordering <= ?" + " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1" ) txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.db.runInteraction("get_push_action_users_in_range", f) + ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f) return ret - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_http( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_http( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions". The list will be ordered by ascending stream_ordering. The list will have between 0~limit entries. """ @@ -224,13 +249,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -252,13 +278,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -282,23 +309,25 @@ class EventPushActionsWorkerStore(SQLBaseStore): # one of the subqueries may have hit the limit. return notifs[:limit] - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_email( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_email( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions", "received_ts". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". The list will be ordered by descending received_ts. The list will have between 0~limit entries. """ @@ -324,13 +353,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -352,13 +382,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" + " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -383,62 +414,66 @@ class EventPushActionsWorkerStore(SQLBaseStore): # Now return the first `limit` return notifs[:limit] - def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): + async def get_if_maybe_push_in_range_for_user( + self, user_id: str, min_stream_ordering: int + ) -> bool: """A fast check to see if there might be something to push for the user since the given stream ordering. May return false positives. Useful to know whether to bother starting a pusher on start up or not. Args: - user_id (str) - min_stream_ordering (int) + user_id + min_stream_ordering Returns: - Deferred[bool]: True if there may be push to process, False if - there definitely isn't. + True if there may be push to process, False if there definitely isn't. """ def _get_if_maybe_push_in_range_for_user_txn(txn): sql = """ SELECT 1 FROM event_push_actions - WHERE user_id = ? AND stream_ordering > ? + WHERE user_id = ? AND stream_ordering > ? AND notif = 1 LIMIT 1 """ txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) - def add_push_actions_to_staging(self, event_id, user_id_actions): + async def add_push_actions_to_staging( + self, + event_id: str, + user_id_actions: Dict[str, List[Union[dict, str]]], + count_as_unread: bool, + ) -> None: """Add the push actions for the event to the push action staging area. Args: - event_id (str) - user_id_actions (dict[str, list[dict|str])]): A dictionary mapping - user_id to list of push actions, where an action can either be - a string or dict. - - Returns: - Deferred + event_id + user_id_actions: A mapping of user_id to list of push actions, where + an action can either be a string or dict. + count_as_unread: Whether this event should increment unread counts. """ - if not user_id_actions: return # This is a helper function for generating the necessary tuple that - # can be used to inert into the `event_push_actions_staging` table. + # can be used to insert into the `event_push_actions_staging` table. def _gen_entry(user_id, actions): is_highlight = 1 if _action_has_highlight(actions) else 0 + notif = 1 if "notify" in actions else 0 return ( event_id, # event_id column user_id, # user_id column _serialize_action(actions, is_highlight), # actions column - 1, # notif column + notif, # notif column is_highlight, # highlight column + int(count_as_unread), # unread column ) def _add_push_actions_to_staging_txn(txn): @@ -447,33 +482,29 @@ class EventPushActionsWorkerStore(SQLBaseStore): sql = """ INSERT INTO event_push_actions_staging - (event_id, user_id, actions, notif, highlight) - VALUES (?, ?, ?, ?, ?) + (event_id, user_id, actions, notif, highlight, unread) + VALUES (?, ?, ?, ?, ?, ?) """ txn.executemany( sql, ( _gen_entry(user_id, actions) - for user_id, actions in iteritems(user_id_actions) + for user_id, actions in user_id_actions.items() ), ) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) - @defer.inlineCallbacks - def remove_push_actions_from_staging(self, event_id): + async def remove_push_actions_from_staging(self, event_id: str) -> None: """Called if we failed to persist the event to ensure that stale push actions don't build up in the DB - - Args: - event_id (str) """ try: - res = yield self.db.simple_delete( + res = await self.db_pool.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -490,7 +521,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _find_stream_orderings_for_times(self): return run_as_background_process( "event_push_action_stream_orderings", - self.db.runInteraction, + self.db_pool.runInteraction, "_find_stream_orderings_for_times", self._find_stream_orderings_for_times_txn, ) @@ -511,7 +542,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago ) - def find_first_stream_ordering_after_ts(self, ts): + async def find_first_stream_ordering_after_ts(self, ts: int) -> int: """Gets the stream ordering corresponding to a given timestamp. Specifically, finds the stream_ordering of the first event that was @@ -520,13 +551,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): relatively slow. Args: - ts (int): timestamp in millis + ts: timestamp in millis Returns: - Deferred[int]: stream ordering of the first event received on/after - the timestamp + stream ordering of the first event received on/after the timestamp """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, @@ -608,38 +638,39 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end - @defer.inlineCallbacks - def get_time_of_last_push_action_before(self, stream_ordering): + async def get_time_of_last_push_action_before(self, stream_ordering): def f(txn): sql = ( "SELECT e.received_ts" " FROM event_push_actions AS ep" " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.stream_ordering > ?" + " WHERE ep.stream_ordering > ? AND notif = 1" " ORDER BY ep.stream_ordering ASC" " LIMIT 1" ) txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) + result = await self.db_pool.runInteraction( + "get_time_of_last_push_action_before", f + ) return result[0] if result else None class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventPushActionsStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, index_name="event_push_actions_u_highlight", table="event_push_actions", columns=["user_id", "stream_ordering"], ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "event_push_actions_highlights_index", index_name="event_push_actions_highlights_index", table="event_push_actions", @@ -652,8 +683,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): self._start_rotate_notifs, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def get_push_actions_for_user( + async def get_push_actions_for_user( self, user_id, before=None, limit=50, only_highlight=False ): def f(txn): @@ -678,24 +708,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " FROM event_push_actions epa, events e" " WHERE epa.event_id = e.event_id" " AND epa.user_id = ? %s" + " AND epa.notif = 1" " ORDER BY epa.stream_ordering DESC" " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) + push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - @defer.inlineCallbacks - def get_latest_push_action_stream_ordering(self): + async def get_latest_push_action_stream_ordering(self): def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( "get_latest_push_action_stream_ordering", f ) return result[0] or 0 @@ -749,8 +779,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): def _start_rotate_notifs(self): return run_as_background_process("rotate_notifs", self._rotate_notifs) - @defer.inlineCallbacks - def _rotate_notifs(self): + async def _rotate_notifs(self): if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return self._doing_notif_rotation = True @@ -759,12 +788,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore): while True: logger.info("Rotating notifications") - caught_up = yield self.db.runInteraction( + caught_up = await self.db_pool.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: break - yield self.hs.get_clock().sleep(self._rotate_delay) + await self.hs.get_clock().sleep(self._rotate_delay) finally: self._doing_notif_rotation = False @@ -773,7 +802,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): the archiving process has caught up or not. """ - old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -809,7 +838,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -819,49 +848,100 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, - coalesce(old.notif_count, 0) + upd.notif_count, + coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering, old.user_id FROM ( - SELECT user_id, room_id, count(*) as notif_count, + SELECT user_id, room_id, count(*) as cnt, max(stream_ordering) as stream_ordering FROM event_push_actions WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0 + AND %s = 1 GROUP BY user_id, room_id ) AS upd LEFT JOIN event_push_summary AS old USING (user_id, room_id) """ - txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) - rows = txn.fetchall() + # First get the count of unread messages. + txn.execute( + sql % ("unread_count", "unread"), + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + # We need to merge results from the two requests (the one that retrieves the + # unread count and the one that retrieves the notifications count) into a single + # object because we might not have the same amount of rows in each of them. To do + # this, we use a dict indexed on the user ID and room ID to make it easier to + # populate. + summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary] + for row in txn: + summaries[(row[0], row[1])] = _EventPushSummary( + unread_count=row[2], + stream_ordering=row[3], + old_user_id=row[4], + notif_count=0, + ) + + # Then get the count of notifications. + txn.execute( + sql % ("notif_count", "notif"), + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + for row in txn: + if (row[0], row[1]) in summaries: + summaries[(row[0], row[1])].notif_count = row[2] + else: + # Because the rules on notifying are different than the rules on marking + # a message unread, we might end up with messages that notify but aren't + # marked unread, so we might not have a summary for this (user, room) + # tuple to complete. + summaries[(row[0], row[1])] = _EventPushSummary( + unread_count=0, + stream_ordering=row[3], + old_user_id=row[4], + notif_count=row[2], + ) - logger.info("Rotating notifications, handling %d rows", len(rows)) + logger.info("Rotating notifications, handling %d rows", len(summaries)) # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_push_summary", values=[ { - "user_id": row[0], - "room_id": row[1], - "notif_count": row[2], - "stream_ordering": row[3], + "user_id": user_id, + "room_id": room_id, + "notif_count": summary.notif_count, + "unread_count": summary.unread_count, + "stream_ordering": summary.stream_ordering, } - for row in rows - if row[4] is None + for ((user_id, room_id), summary) in summaries.items() + if summary.old_user_id is None ], ) txn.executemany( """ - UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + UPDATE event_push_summary + SET notif_count = ?, unread_count = ?, stream_ordering = ? WHERE user_id = ? AND room_id = ? """, - ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), + ( + ( + summary.notif_count, + summary.unread_count, + summary.stream_ordering, + user_id, + room_id, + ) + for ((user_id, room_id), summary) in summaries.items() + if summary.old_user_id is not None + ), ) txn.execute( @@ -887,3 +967,15 @@ def _action_has_highlight(actions): pass return False + + +@attr.s +class _EventPushSummary: + """Summary of pending event push actions for a given user in a given room. + Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. + """ + + unread_count = attr.ib(type=int) + stream_ordering = attr.ib(type=int) + old_user_id = attr.ib(type=str) + notif_count = attr.ib(type=int) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/databases/main/events.py index a6572571b4..b3d27a2ee7 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -14,45 +14,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging from collections import OrderedDict, namedtuple -from functools import wraps -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple - -from six import integer_types, iteritems, text_type -from six.moves import range +from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple import attr -from canonicaljson import json from prometheus_client import Counter -from twisted.internet import defer - import synapse.metrics -from synapse.api.constants import ( - EventContentFields, - EventTypes, - Membership, - RelationTypes, -) +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.logging.utils import log_function -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.data_stores.main.search import SearchEntry -from synapse.storage.database import Database, LoggingTransaction +from synapse.storage._base import db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter if TYPE_CHECKING: - from synapse.storage.data_stores.main import DataStore from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -78,27 +65,6 @@ def encode_json(json_object): _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) -def _retry_on_integrity_error(func): - """Wraps a database function so that it gets retried on IntegrityError, - with `delete_existing=True` passed in. - - Args: - func: function that returns a Deferred and accepts a `delete_existing` arg - """ - - @wraps(func) - @defer.inlineCallbacks - def f(self, *args, **kwargs): - try: - res = yield func(self, *args, delete_existing=False, **kwargs) - except self.database_engine.module.IntegrityError: - logger.exception("IntegrityError, retrying.") - res = yield func(self, *args, delete_existing=True, **kwargs) - return res - - return f - - @attr.s(slots=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -123,9 +89,11 @@ class PersistEventsStore: Note: This is not part of the `DataStore` mixin. """ - def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"): + def __init__( + self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore" + ): self.hs = hs - self.db = db + self.db_pool = db self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() @@ -143,17 +111,14 @@ class PersistEventsStore: hs.config.worker.writers.events == hs.get_instance_name() ), "Can only instantiate EventsStore on master" - @_retry_on_integrity_error - @defer.inlineCallbacks - def _persist_events_and_state_updates( + async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], backfilled: bool = False, - delete_existing: bool = False, - ): + ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -166,10 +131,9 @@ class PersistEventsStore: new_forward_extremities: Map from room_id to list of event IDs that are the new forward extremities of the room. backfilled - delete_existing Returns: - Deferred: resolves when the events have been persisted + Resolves when the events have been persisted """ # We want to calculate the stream orderings as late as possible, as @@ -189,11 +153,11 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( + stream_ordering_manager = await self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( + stream_ordering_manager = await self._stream_id_gen.get_next_mult( len(events_and_contexts) ) @@ -201,12 +165,11 @@ class PersistEventsStore: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream - yield self.db.runInteraction( + await self.db_pool.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, backfilled=backfilled, - delete_existing=delete_existing, state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) @@ -232,24 +195,23 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - for room_id, new_state in iteritems(current_state_for_room): + for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) - for room_id, latest_event_ids in iteritems(new_forward_extremeties): + for room_id, latest_event_ids in new_forward_extremeties.items(): self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) - @defer.inlineCallbacks - def _get_events_which_are_prevs(self, event_ids): + async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: """Filter the supplied list of event_ids to get those which are prev_events of existing (non-outlier/rejected) events. Args: - event_ids (Iterable[str]): event ids to filter + event_ids: event ids to filter Returns: - Deferred[List[str]]: filtered event ids + Filtered event ids """ results = [] @@ -271,17 +233,16 @@ class PersistEventsStore: ) txn.execute(sql + clause, args) - results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) + results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): - yield self.db.runInteraction( + await self.db_pool.runInteraction( "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) return results - @defer.inlineCallbacks - def _get_prevs_before_rejected(self, event_ids): + async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]: """Get soft-failed ancestors to remove from the extremities. Given a set of events, find all those that have been soft-failed or @@ -293,11 +254,11 @@ class PersistEventsStore: are separated by soft failed events. Args: - event_ids (Iterable[str]): Events to find prev events for. Note - that these must have already been persisted. + event_ids: Events to find prev events for. Note that these must have + already been persisted. Returns: - Deferred[set[str]] + The previous events. """ # The set of event_ids to return. This includes all soft-failed events @@ -332,13 +293,13 @@ class PersistEventsStore: if prev_event_id in existing_prevs: continue - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: to_recursively_check.append(prev_event_id) existing_prevs.add(prev_event_id) for chunk in batch_iter(event_ids, 100): - yield self.db.runInteraction( + await self.db_pool.runInteraction( "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) @@ -350,7 +311,6 @@ class PersistEventsStore: txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool, - delete_existing: bool = False, state_delta_for_room: Dict[str, DeltaState] = {}, new_forward_extremeties: Dict[str, List[str]] = {}, ): @@ -402,13 +362,6 @@ class PersistEventsStore: # From this point onwards the events are only events that we haven't # seen before. - if delete_existing: - # For paranoia reasons, we go and delete all the existing entries - # for these events so we can reinsert them. - # This gets around any problems with some tables already having - # entries. - self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) - self._store_event_txn(txn, events_and_contexts=events_and_contexts) # Insert into event_to_state_groups. @@ -420,7 +373,7 @@ class PersistEventsStore: # event's auth chain, but its easier for now just to store them (and # it doesn't take much storage compared to storing the entire event # anyway). - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -461,7 +414,7 @@ class PersistEventsStore: state_delta_by_room: Dict[str, DeltaState], stream_id: int, ): - for room_id, delta_state in iteritems(state_delta_by_room): + for room_id, delta_state in state_delta_by_room.items(): to_delete = delta_state.to_delete to_insert = delta_state.to_insert @@ -483,7 +436,7 @@ class PersistEventsStore: """ txn.execute(sql, (stream_id, room_id)) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="current_state_events", keyvalues={"room_id": room_id}, ) else: @@ -545,7 +498,7 @@ class PersistEventsStore: """, [ (room_id, key[0], key[1], ev_id, ev_id) - for key, ev_id in iteritems(to_insert) + for key, ev_id in to_insert.items() ], ) @@ -626,12 +579,12 @@ class PersistEventsStore: txn.execute(sql, (room_id, EventTypes.Create, "")) row = txn.fetchone() if row: - event_json = json.loads(row[0]) + event_json = db_to_json(row[0]) content = event_json.get("content", {}) creator = content.get("creator") room_version_id = content.get("room_version", RoomVersions.V1.identifier) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="rooms", keyvalues={"room_id": room_id}, @@ -642,20 +595,20 @@ class PersistEventsStore: def _update_forward_extremities_txn( self, txn, new_forward_extremities, max_stream_order ): - for room_id, new_extrem in iteritems(new_forward_extremities): - self.db.simple_delete_txn( + for room_id, new_extrem in new_forward_extremities.items(): + self.db_pool.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) txn.call_after( self.store.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_forward_extremities", values=[ {"event_id": ev_id, "room_id": room_id} - for room_id, new_extrem in iteritems(new_forward_extremities) + for room_id, new_extrem in new_forward_extremities.items() for ev_id in new_extrem ], ) @@ -663,7 +616,7 @@ class PersistEventsStore: # new stream_ordering to new forward extremeties in the room. # This allows us to later efficiently look up the forward extremeties # for a room before a given stream_ordering - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", values=[ @@ -672,7 +625,7 @@ class PersistEventsStore: "event_id": event_id, "stream_ordering": max_stream_order, } - for room_id, new_extrem in iteritems(new_forward_extremities) + for room_id, new_extrem in new_forward_extremities.items() for event_id in new_extrem ], ) @@ -727,7 +680,7 @@ class PersistEventsStore: event.depth, depth_updates.get(event.room_id, event.depth) ) - for room_id, depth in iteritems(depth_updates): + for room_id, depth in depth_updates.items(): self._update_min_depth_for_room_txn(txn, room_id, depth) def _update_outliers_txn(self, txn, events_and_contexts): @@ -787,7 +740,7 @@ class PersistEventsStore: # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering state_group_id = context.state_group - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="ex_outlier_stream", values={ @@ -806,40 +759,6 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] - @classmethod - def _delete_existing_rows_txn(cls, txn, events_and_contexts): - if not events_and_contexts: - # nothing to do here - return - - logger.info("Deleting existing") - - for table in ( - "events", - "event_auth", - "event_json", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "event_to_state_groups", - "local_invites", - "state_events", - "rejections", - "redactions", - "room_memberships", - ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts], - ) - - for table in ("event_push_actions",): - txn.executemany( - "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), - [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], - ) - def _store_event_txn(self, txn, events_and_contexts): """Insert new events into the event and event_json tables @@ -859,7 +778,7 @@ class PersistEventsStore: d.pop("redacted_because", None) return d - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_json", values=[ @@ -876,7 +795,7 @@ class PersistEventsStore: ], ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="events", values=[ @@ -893,8 +812,7 @@ class PersistEventsStore: "received_ts": self._clock.time_msec(), "sender": event.sender, "contains_url": ( - "url" in event.content - and isinstance(event.content["url"], text_type) + "url" in event.content and isinstance(event.content["url"], str) ), } for event, _ in events_and_contexts @@ -906,7 +824,7 @@ class PersistEventsStore: # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="redactions", keyvalues={"redacts": event.event_id}, @@ -1048,7 +966,9 @@ class PersistEventsStore: state_values.append(vals) - self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) + self.db_pool.simple_insert_many_txn( + txn, table="state_events", values=state_values + ) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1079,7 +999,7 @@ class PersistEventsStore: ) txn.execute(sql + clause, args) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: @@ -1099,7 +1019,7 @@ class PersistEventsStore: # invalidate the cache for the redacted event txn.call_after(self.store._invalidate_get_event_cache, event.redacts) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="redactions", values={ @@ -1122,7 +1042,7 @@ class PersistEventsStore: room_id (str): The ID of the room the event was sent to. topological_ordering (int): The position of the event in the room's topology. """ - return self.db.simple_insert_many_txn( + return self.db_pool.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -1144,7 +1064,7 @@ class PersistEventsStore: event_id (str): The event ID the expiry timestamp is associated with. expiry_ts (int): The timestamp at which to expire (delete) the event. """ - return self.db.simple_insert_txn( + return self.db_pool.simple_insert_txn( txn=txn, table="event_expiry", values={"event_id": event_id, "expiry_ts": expiry_ts}, @@ -1168,12 +1088,14 @@ class PersistEventsStore: } ) - self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + self.db_pool.simple_insert_many_txn( + txn, table="event_reference_hashes", values=vals + ) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="room_memberships", values=[ @@ -1201,65 +1123,27 @@ class PersistEventsStore: (event.state_key,), ) - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. If the event is an - # outlier it is only current if its an "out of band membership", - # like a remote invite or a rejection of a remote invite. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_out_of_band_membership() - ) - is_mine = self.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self.db.simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - }, - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute( - sql, - ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - ), - ) - - # We also update the `local_current_membership` table with - # latest invite info. This will usually get updated by the - # `current_state_events` handling, unless its an outlier. - if event.internal_metadata.is_outlier(): - # This should only happen for out of band memberships, so - # we add a paranoia check. - assert event.internal_metadata.is_out_of_band_membership() - - self.db.simple_upsert_txn( - txn, - table="local_current_membership", - keyvalues={ - "room_id": event.room_id, - "user_id": event.state_key, - }, - values={ - "event_id": event.event_id, - "membership": event.membership, - }, - ) + # We update the local_current_membership table only if the event is + # "current", i.e., its something that has just happened. + # + # This will usually get updated by the `current_state_events` handling, + # unless its an outlier, and an outlier is only "current" if it's an "out of + # band membership", like a remote invite or a rejection of a remote invite. + if ( + self.is_mine_id(event.state_key) + and not backfilled + and event.internal_metadata.is_outlier() + and event.internal_metadata.is_out_of_band_membership() + ): + self.db_pool.simple_upsert_txn( + txn, + table="local_current_membership", + keyvalues={"room_id": event.room_id, "user_id": event.state_key}, + values={ + "event_id": event.event_id, + "membership": event.membership, + }, + ) def _handle_event_relations(self, txn, event): """Handles inserting relation data during peristence of events @@ -1289,7 +1173,7 @@ class PersistEventsStore: aggregation_key = relation.get("key") - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="event_relations", values={ @@ -1317,7 +1201,7 @@ class PersistEventsStore: redacted_event_id (str): The event that was redacted. """ - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) @@ -1345,15 +1229,15 @@ class PersistEventsStore: ): if ( "min_lifetime" in event.content - and not isinstance(event.content.get("min_lifetime"), integer_types) + and not isinstance(event.content.get("min_lifetime"), int) ) or ( "max_lifetime" in event.content - and not isinstance(event.content.get("max_lifetime"), integer_types) + and not isinstance(event.content.get("max_lifetime"), int) ): # Ignore the event if one of the value isn't an integer. return - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -1412,9 +1296,9 @@ class PersistEventsStore: sql = """ INSERT INTO event_push_actions ( room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight + topological_ordering, notif, highlight, unread ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread FROM event_push_actions_staging WHERE event_id = ? """ @@ -1434,7 +1318,7 @@ class PersistEventsStore: ) for event, _ in events_and_contexts: - user_ids = self.db.simple_select_onecol_txn( + user_ids = self.db_pool.simple_select_onecol_txn( txn, table="event_push_actions_staging", keyvalues={"event_id": event.event_id}, @@ -1466,7 +1350,7 @@ class PersistEventsStore: ) def _store_rejections_txn(self, txn, event_id, reason): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="rejections", values={ @@ -1492,16 +1376,16 @@ class PersistEventsStore: state_groups[event.event_id] = context.state_group - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_to_state_groups", values=[ {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in iteritems(state_groups) + for event_id, state_group_id in state_groups.items() ], ) - for event_id, state_group_id in iteritems(state_groups): + for event_id, state_group_id in state_groups.items(): txn.call_after( self.store._get_state_group_for_event.prefill, (event_id,), @@ -1514,7 +1398,7 @@ class PersistEventsStore: if min_depth is not None and depth >= min_depth: return - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -1526,7 +1410,7 @@ class PersistEventsStore: For the given event, update the event edges table and forward and backward extremities tables. """ - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_edges", values=[ @@ -1590,31 +1474,3 @@ class PersistEventsStore: if not ev.internal_metadata.is_outlier() ], ) - - async def locally_reject_invite(self, user_id: str, room_id: str) -> int: - """Mark the invite has having been rejected even though we failed to - create a leave event for it. - """ - - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, (stream_ordering, True, room_id, user_id)) - - # We also clear this entry from `local_current_membership`. - # Ideally we'd point to a leave event, but we don't have one, so - # nevermind. - self.db.simple_delete_txn( - txn, - table="local_current_membership", - keyvalues={"room_id": room_id, "user_id": user_id}, - ) - - with self._stream_id_gen.get_next() as stream_ordering: - await self.db.runInteraction("locally_reject_invite", f, stream_ordering) - - return stream_ordering diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index f54c8b1ee0..e53c6373a8 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -15,15 +15,9 @@ import logging -from six import text_type - -from canonicaljson import json - -from twisted.internet import defer - from synapse.api.constants import EventContentFields -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool logger = logging.getLogger(__name__) @@ -34,18 +28,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "event_contains_url_index", index_name="event_contains_url_index", table="events", @@ -56,7 +50,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # an event_id index on event_search is useful for the purge_history # api. Plus it means we get to enforce some integrity with a UNIQUE # clause - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "event_search_event_id_idx", index_name="event_search_event_id_idx", table="event_search", @@ -65,16 +59,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): psql_only=True, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "redactions_received_ts", self._redactions_received_ts ) # This index gets deleted in `event_fix_redactions_bytes` update - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "event_fix_redactions_bytes_create_index", index_name="redactions_censored_redacts", table="redactions", @@ -82,15 +76,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="have_censored", ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "event_fix_redactions_bytes", self._event_fix_redactions_bytes ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "event_store_labels", self._event_store_labels ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "redactions_have_censored_ts_idx", index_name="redactions_have_censored_ts", table="redactions", @@ -98,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="NOT have_censored", ) - @defer.inlineCallbacks - def _background_reindex_fields_sender(self, progress, batch_size): + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -127,13 +120,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): for row in rows: try: event_id = row[1] - event_json = json.loads(row[2]) + event_json = db_to_json(row[2]) sender = event_json["sender"] content = event_json["content"] contains_url = "url" in content if contains_url: - contains_url &= isinstance(content["url"], text_type) + contains_url &= isinstance(content["url"], str) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. @@ -153,25 +146,24 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): "rows_inserted": rows_inserted + len(rows), } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) return len(rows) - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME ) return result - @defer.inlineCallbacks - def _background_reindex_origin_server_ts(self, progress, batch_size): + async def _background_reindex_origin_server_ts(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -199,7 +191,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self.db.simple_select_many_txn( + ev_rows = self.db_pool.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -210,7 +202,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): for row in ev_rows: event_id = row["event_id"] - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) try: origin_server_ts = event_json["origin_server_ts"] except (KeyError, AttributeError): @@ -232,25 +224,24 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): "rows_inserted": rows_inserted + len(rows_to_update), } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress ) return len(rows_to_update) - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_ORIGIN_SERVER_TS_NAME ) return result - @defer.inlineCallbacks - def _cleanup_extremities_bg_update(self, progress, batch_size): + async def _cleanup_extremities_bg_update(self, progress, batch_size): """Background update to clean out extremities that should have been deleted previously. @@ -319,7 +310,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): soft_failed = False if metadata: - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: soft_failed_events_to_lookup.add(event_id) @@ -360,7 +351,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): graph[event_id] = {prev_event_id} - soft_failed = json.loads(metadata).get("soft_failed") + soft_failed = db_to_json(metadata).get("soft_failed") if soft_failed or rejected: soft_failed_events_to_lookup.add(event_id) else: @@ -378,7 +369,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): to_delete.intersection_update(original_set) - deleted = self.db.simple_delete_many_txn( + deleted = self.db_pool.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -394,7 +385,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): if deleted: # We now need to invalidate the caches of these rooms - rows = self.db.simple_select_many_txn( + rows = self.db_pool.simple_select_many_txn( txn, table="events", column="event_id", @@ -408,7 +399,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -418,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return len(original_set) - num_handled = yield self.db.runInteraction( + num_handled = await self.db_pool.runInteraction( "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.DELETE_SOFT_FAILED_EXTREMITIES ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") - yield self.db.runInteraction( + await self.db_pool.runInteraction( "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) return num_handled - @defer.inlineCallbacks - def _redactions_received_ts(self, progress, batch_size): + async def _redactions_received_ts(self, progress, batch_size): """Handles filling out the `received_ts` column in redactions. """ last_event_id = progress.get("last_event_id", "") @@ -478,23 +468,22 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "redactions_received_ts", {"last_event_id": upper_event_id} ) return len(rows) - count = yield self.db.runInteraction( + count = await self.db_pool.runInteraction( "_redactions_received_ts", _redactions_received_ts_txn ) if not count: - yield self.db.updates._end_background_update("redactions_received_ts") + await self.db_pool.updates._end_background_update("redactions_received_ts") return count - @defer.inlineCallbacks - def _event_fix_redactions_bytes(self, progress, batch_size): + async def _event_fix_redactions_bytes(self, progress, batch_size): """Undoes hex encoded censored redacted event JSON. """ @@ -515,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): txn.execute("DROP INDEX redactions_censored_redacts") - yield self.db.runInteraction( + await self.db_pool.runInteraction( "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) - yield self.db.updates._end_background_update("event_fix_redactions_bytes") + await self.db_pool.updates._end_background_update("event_fix_redactions_bytes") return 1 - @defer.inlineCallbacks - def _event_store_labels(self, progress, batch_size): + async def _event_store_labels(self, progress, batch_size): """Background update handler which will store labels for existing events.""" last_event_id = progress.get("last_event_id", "") @@ -545,9 +533,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): last_row_event_id = "" for (event_id, event_json_raw) in results: try: - event_json = json.loads(event_json_raw) + event_json = db_to_json(event_json_raw) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -573,17 +561,17 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): nbrows += 1 last_row_event_id = event_id - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "event_store_labels", {"last_event_id": last_row_event_id} ) return nbrows - num_rows = yield self.db.runInteraction( + num_rows = await self.db_pool.runInteraction( desc="event_store_labels", func=_event_store_labels_txn ) if not num_rows: - yield self.db.updates._end_background_update("event_store_labels") + await self.db_pool.updates._end_background_update("event_store_labels") return num_rows diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 213d69100a..a7a73cc3d8 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,10 +19,10 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload -from canonicaljson import json from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -33,16 +33,18 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.replication.tcp.streams import BackfillStream +from synapse.replication.tcp.streams.events import EventsStream +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.types import Collection, get_domain_from_id +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -73,17 +75,14 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) if hs.config.worker.writers.events == hs.get_instance_name(): # We are the process in charge of generating stream ids for events, # so instantiate ID generators based on the database self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - extra_tables=[("local_invites", "stream_id")], + db_conn, "events", "stream_ordering", ) self._backfill_id_gen = StreamIdGenerator( db_conn, @@ -113,70 +112,59 @@ class EventsWorkerStore(SQLBaseStore): self._event_fetch_ongoing = 0 def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "events": - self._stream_id_gen.advance(token) - elif stream_name == "backfill": - self._backfill_id_gen.advance(-token) + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(instance_name, token) + elif stream_name == BackfillStream.NAME: + self._backfill_id_gen.advance(instance_name, -token) super().process_replication_rows(stream_name, instance_name, token, rows) - def get_received_ts(self, event_id): + async def get_received_ts(self, event_id: str) -> Optional[int]: """Get received_ts (when it was persisted) for the event. Raises an exception for unknown events. Args: - event_id (str) + event_id: The event ID to query. Returns: - Deferred[int|None]: Timestamp in milliseconds, or None for events - that were persisted before received_ts was implemented. + Timestamp in milliseconds, or None for events that were persisted + before received_ts was implemented. """ - return self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", desc="get_received_ts", ) - def get_received_ts_by_stream_pos(self, stream_ordering): - """Given a stream ordering get an approximate timestamp of when it - happened. - - This is done by simply taking the received ts of the first event that - has a stream ordering greater than or equal to the given stream pos. - If none exists returns the current time, on the assumption that it must - have happened recently. - - Args: - stream_ordering (int) - - Returns: - Deferred[int] - """ - - def _get_approximate_received_ts_txn(txn): - sql = """ - SELECT received_ts FROM events - WHERE stream_ordering >= ? - LIMIT 1 - """ - - txn.execute(sql, (stream_ordering,)) - row = txn.fetchone() - if row and row[0]: - ts = row[0] - else: - ts = self.clock.time_msec() - - return ts + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... - return self.db.runInteraction( - "get_approximate_received_ts", _get_approximate_received_ts_txn - ) + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... - @defer.inlineCallbacks - def get_event( + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -184,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -209,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -232,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore): return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -258,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore): omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -269,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, - event_ids: List[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. @@ -297,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore): omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. @@ -308,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -343,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -409,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore): if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -421,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore): return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -437,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -455,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -539,7 +524,7 @@ class EventsWorkerStore(SQLBaseStore): event_id for events, _ in event_list for event_id in events } - row_dict = self.db.new_transaction( + row_dict = self.db_pool.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -563,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -578,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. """ @@ -586,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -612,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected and rejected_reason: continue - d = json.loads(row["json"]) - internal_metadata = json.loads(row["internal_metadata"]) + # If the event or metadata cannot be parsed, log the error and act + # as if the event is unknown. + try: + d = db_to_json(row["json"]) + except ValueError: + logger.error("Unable to parse json from event: %s", event_id) + continue + try: + internal_metadata = db_to_json(row["internal_metadata"]) + except ValueError: + logger.error( + "Unable to parse internal_metadata from event: %s", event_id + ) + continue format_version = row["format_version"] if format_version is None: @@ -624,24 +620,43 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row["room_version_id"] if not room_version_id: - # this should only happen for out-of-band membership events - if not internal_metadata.get("out_of_band_membership"): - logger.warning( - "Room %s for event %s is unknown", d["room_id"], event_id + # this should only happen for out-of-band membership events which + # arrived before #6983 landed. For all other events, we should have + # an entry in the 'rooms' table. + # + # However, the 'out_of_band_membership' flag is unreliable for older + # invites, so just accept it for all membership events. + # + if d["type"] != EventTypes.Member: + raise Exception( + "Room %s for event %s is unknown" % (d["room_id"], event_id) ) - continue - # take a wild stab at the room version based on the event format + # so, assuming this is an out-of-band-invite that arrived before #6983 + # landed, we know that the room version must be v5 or earlier (because + # v6 hadn't been invented at that point, so invites from such rooms + # would have been rejected.) + # + # The main reason we need to know the room version here (other than + # choosing the right python Event class) is in case the event later has + # to be redacted - and all the room versions up to v5 used the same + # redaction algorithm. + # + # So, the following approximations should be adequate. + if format_version == EventFormatVersions.V1: + # if it's event format v1 then it must be room v1 or v2 room_version = RoomVersions.V1 elif format_version == EventFormatVersions.V2: + # if it's event format v2 then it must be room v3 room_version = RoomVersions.V3 else: + # if it's event format v3 then it must be room v4 or v5 room_version = RoomVersions.V5 else: room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) if not room_version: - logger.error( + logger.warning( "Event %s in room %s has unknown room version %s", event_id, d["room_id"], @@ -688,8 +703,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -698,7 +712,7 @@ class EventsWorkerStore(SQLBaseStore): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. + Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -716,12 +730,12 @@ class EventsWorkerStore(SQLBaseStore): if should_start: run_as_background_process( - "fetch_events", self.db.runWithConnection, self._do_fetch + "fetch_events", self.db_pool.runWithConnection, self._do_fetch ) logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -809,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore): return event_dict - def _maybe_redact_event_row(self, original_ev, redactions, event_map): + def _maybe_redact_event_row( + self, + original_ev: EventBase, + redactions: Iterable[str], + event_map: Dict[str, EventBase], + ) -> Optional[EventBase]: """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. Args: - original_ev (EventBase): - redactions (iterable[str]): list of event ids of potential redaction events - event_map (dict[str, EventBase]): other events which have been fetched, in - which we can look up the redaaction events. Map from event id to event. + original_ev: The original event. + redactions: list of event ids of potential redaction events + event_map: other events which have been fetched, in which we can + look up the redaaction events. Map from event id to event. Returns: - Deferred[EventBase|None]: if the event should be redacted, a pruned - event object. Otherwise, None. + If the event should be redacted, a pruned event object. Otherwise, None. """ if original_ev.type == "m.room.create": # we choose to ignore redactions of m.room.create events. @@ -880,12 +898,11 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -896,15 +913,14 @@ class EventsWorkerStore(SQLBaseStore): return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -920,41 +936,11 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results - def _get_total_state_event_counts_txn(self, txn, room_id): - """ - See get_total_state_event_counts. - """ - # We join against the events table as that has an index on room_id - sql = """ - SELECT COUNT(*) FROM state_events - INNER JOIN events USING (room_id, event_id) - WHERE room_id=? - """ - txn.execute(sql, (room_id,)) - row = txn.fetchone() - return row[0] if row else 0 - - def get_total_state_event_counts(self, room_id): - """ - Gets the total number of state events in a room. - - Args: - room_id (str) - - Returns: - Deferred[int] - """ - return self.db.runInteraction( - "get_total_state_event_counts", - self._get_total_state_event_counts_txn, - room_id, - ) - def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. @@ -964,24 +950,23 @@ class EventsWorkerStore(SQLBaseStore): row = txn.fetchone() return row[0] if row else 0 - def get_current_state_event_counts(self, room_id): + async def get_current_state_event_counts(self, room_id: str) -> int: """ Gets the current number of state events in a room. Args: - room_id (str) + room_id: The room ID to query. Returns: - Deferred[int] + The current number of state events. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -992,9 +977,9 @@ class EventsWorkerStore(SQLBaseStore): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. @@ -1010,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() - def get_all_new_forward_event_rows(self, last_id, current_id, limit): + async def get_all_new_forward_event_rows( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple]: """Returns new events, for the Events replication stream Args: @@ -1018,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore): current_id: the maximum stream_id to return up to limit: the maximum number of rows to return - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. @@ -1039,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) - def get_ex_outlier_stream_rows(self, last_id, current_id): + async def get_ex_outlier_stream_rows( + self, last_id: int, current_id: int + ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. @@ -1073,13 +1062,36 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + async def get_all_new_backfill_event_rows( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for backfill replication stream, including all new + backfilled events and events that have gone from being outliers to not. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_new_backfill_event_rows(txn): sql = ( @@ -1094,10 +1106,12 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() + new_event_updates = [(row[0], row[1:]) for row in txn] + limited = False if len(new_event_updates) == limit: upper_bound = new_event_updates[-1][0] + limited = True else: upper_bound = current_id @@ -1114,11 +1128,15 @@ class EventsWorkerStore(SQLBaseStore): " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) + new_event_updates.extend((row[0], row[1:]) for row in txn) + + if len(new_event_updates) >= limit: + upper_bound = new_event_updates[-1][0] + limited = True - return new_event_updates + return new_event_updates, upper_bound, limited - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) @@ -1166,7 +1184,7 @@ class EventsWorkerStore(SQLBaseStore): # we need to make sure that, for every stream id in the results, we get *all* # the rows with that stream id. - rows = await self.db.runInteraction( + rows = await self.db_pool.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) # type: List[Tuple] @@ -1189,103 +1207,12 @@ class EventsWorkerStore(SQLBaseStore): # stream id. let's run the query again, without a row limit, but for # just one stream id. to_token += 1 - rows = await self.db.runInteraction( + rows = await self.db_pool.runInteraction( "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token ) return rows, to_token, True - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, - ) - - return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ @@ -1293,9 +1220,9 @@ class EventsWorkerStore(SQLBaseStore): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1307,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore): return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_next_event_to_expire(self): + async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry table, or None if there's no more event to expire. - Returns: Deferred[Optional[Tuple[str, int]]] + Returns: A tuple containing the event ID as its first element and an expiry timestamp as its second one, if there's at least one row in the event_expiry table. None otherwise. @@ -1327,17 +1254,6 @@ class EventsWorkerStore(SQLBaseStore): return txn.fetchone() - return self.db.runInteraction( + return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/databases/main/filtering.py index 342d6622a4..d2f5b9a502 100644 --- a/synapse/storage/data_stores/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,12 +17,13 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.types import JsonDict +from synapse.util.caches.descriptors import cached class FilteringStore(SQLBaseStore): - @cachedInlineCallbacks(num_args=2) - def get_user_filter(self, user_localpart, filter_id): + @cached(num_args=2) + async def get_user_filter(self, user_localpart, filter_id): # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -30,7 +31,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self.db.simple_select_one_onecol( + def_json = await self.db_pool.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", @@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) - def add_user_filter(self, user_localpart, user_filter): + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.db.runInteraction("add_user_filter", _do_txn) + return await self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/databases/main/group_server.py index 787ce9e584..1cbf31f52d 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,12 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import json - -from twisted.internet import defer +from typing import Any, Dict, List, Optional, Tuple, Union from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict +from synapse.util import json_encoder # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null @@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = "" class GroupServerWorkerStore(SQLBaseStore): - def get_group(self, group_id): - return self.db.simple_select_one( + async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -44,31 +44,35 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_group", ) - def get_users_in_group(self, group_id, include_private=False): + async def get_users_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Any]]: # TODO: Pagination keyvalues = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True - return self.db.simple_select_list( + return await self.db_pool.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), desc="get_users_in_group", ) - def get_invited_users_in_group(self, group_id): + async def get_invited_users_in_group(self, group_id: str) -> List[str]: # TODO: Pagination - return self.db.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id: str, include_private: bool = False): + async def get_rooms_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Union[str, bool]]]: """Retrieve the rooms that belong to a given group. Does not return rooms that lack members. @@ -77,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore): include_private: Whether to return private rooms in results Returns: - Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the - form of: + A list of dictionaries, each in the form of: { "room_id": "!a_room_id:example.com", # The ID of the room @@ -115,11 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore): for room_id, is_public in txn ] - return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) + return await self.db_pool.runInteraction( + "get_rooms_in_group", _get_rooms_in_group_txn + ) - def get_rooms_for_summary_by_category( + async def get_rooms_for_summary_by_category( self, group_id: str, include_private: bool = False, - ): + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Get the rooms and categories that should be included in a summary request Args: @@ -127,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore): include_private: Whether to return private rooms in results Returns: - Deferred[Tuple[List, Dict]]: A tuple containing: + A tuple containing: * A list of dictionaries with the keys: * "room_id": str, the room ID @@ -195,7 +200,7 @@ class GroupServerWorkerStore(SQLBaseStore): categories = { row[0]: { "is_public": row[1], - "profile": json.loads(row[2]), + "profile": db_to_json(row[2]), "order": row[3], } for row in txn @@ -203,13 +208,12 @@ class GroupServerWorkerStore(SQLBaseStore): return rooms, categories - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_summary", _get_rooms_for_summary_txn ) - @defer.inlineCallbacks - def get_group_categories(self, group_id): - rows = yield self.db.simple_select_list( + async def get_group_categories(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -219,27 +223,25 @@ class GroupServerWorkerStore(SQLBaseStore): return { row["category_id"]: { "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + "profile": db_to_json(row["profile"]), } for row in rows } - @defer.inlineCallbacks - def get_group_category(self, group_id, category_id): - category = yield self.db.simple_select_one( + async def get_group_category(self, group_id, category_id): + category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), desc="get_group_category", ) - category["profile"] = json.loads(category["profile"]) + category["profile"] = db_to_json(category["profile"]) return category - @defer.inlineCallbacks - def get_group_roles(self, group_id): - rows = yield self.db.simple_select_list( + async def get_group_roles(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -249,43 +251,42 @@ class GroupServerWorkerStore(SQLBaseStore): return { row["role_id"]: { "is_public": row["is_public"], - "profile": json.loads(row["profile"]), + "profile": db_to_json(row["profile"]), } for row in rows } - @defer.inlineCallbacks - def get_group_role(self, group_id, role_id): - role = yield self.db.simple_select_one( + async def get_group_role(self, group_id, role_id): + role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), desc="get_group_role", ) - role["profile"] = json.loads(role["profile"]) + role["profile"] = db_to_json(role["profile"]) return role - def get_local_groups_for_room(self, room_id): + async def get_local_groups_for_room(self, room_id: str) -> List[str]: """Get all of the local group that contain a given room Args: - room_id (str): The ID of a room + room_id: The ID of a room Returns: - Deferred[list[str]]: A twisted.Deferred containing a list of group ids - containing this room + A list of group ids containing this room """ - return self.db.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", desc="get_local_groups_for_room", ) - def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role(self, group_id, include_private=False): """Get the users and roles that should be included in a summary request - Returns ([users], [roles]) + Returns: + ([users], [roles]) """ def _get_users_for_summary_txn(txn): @@ -331,7 +332,7 @@ class GroupServerWorkerStore(SQLBaseStore): roles = { row[0]: { "is_public": row[1], - "profile": json.loads(row[2]), + "profile": db_to_json(row[2]), "order": row[3], } for row in txn @@ -339,21 +340,24 @@ class GroupServerWorkerStore(SQLBaseStore): return users, roles - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) - def is_user_in_group(self, user_id, group_id): - return self.db.simple_select_one_onecol( + async def is_user_in_group(self, user_id: str, group_id: str) -> bool: + result = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", allow_none=True, desc="is_user_in_group", - ).addCallback(lambda r: bool(r)) + ) + return bool(result) - def is_user_admin_in_group(self, group_id, user_id): - return self.db.simple_select_one_onecol( + async def is_user_admin_in_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -361,10 +365,12 @@ class GroupServerWorkerStore(SQLBaseStore): desc="is_user_admin_in_group", ) - def is_user_invited_to_local_group(self, group_id, user_id): + async def is_user_invited_to_local_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: """Has the group server invited a user? """ - return self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -372,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore): allow_none=True, ) - def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group(self, group_id, user_id): """Get a dict describing the membership of a user in a group. Example if joined: @@ -383,11 +389,12 @@ class GroupServerWorkerStore(SQLBaseStore): "is_privileged": False, } - Returns an empty dict if the user is not join/invite/etc + Returns: + An empty dict if the user is not join/invite/etc """ def _get_users_membership_in_group_txn(txn): - row = self.db.simple_select_one_txn( + row = self.db_pool.simple_select_one_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -402,7 +409,7 @@ class GroupServerWorkerStore(SQLBaseStore): "is_privileged": row["is_admin"], } - row = self.db.simple_select_one_onecol_txn( + row = self.db_pool.simple_select_one_onecol_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -415,21 +422,21 @@ class GroupServerWorkerStore(SQLBaseStore): return {} - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) - def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: """Get all groups a user is publicising """ - return self.db.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", desc="get_publicised_groups_for_user", ) - def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals(self, valid_until_ms): """Get all attestations that need to be renewed until givent time """ @@ -439,18 +446,17 @@ class GroupServerWorkerStore(SQLBaseStore): WHERE valid_until_ms <= ? """ txn.execute(sql, (valid_until_ms,)) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - @defer.inlineCallbacks - def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation(self, group_id, user_id): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self.db.simple_select_one( + row = await self.db_pool.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -460,19 +466,19 @@ class GroupServerWorkerStore(SQLBaseStore): now = int(self._clock.time_msec()) if row and now < row["valid_until_ms"]: - return json.loads(row["attestation_json"]) + return db_to_json(row["attestation_json"]) return None - def get_joined_groups(self, user_id): - return self.db.simple_select_onecol( + async def get_joined_groups(self, user_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", desc="get_joined_groups", ) - def get_all_groups_for_user(self, user_id, now_token): + async def get_all_groups_for_user(self, user_id, now_token): def _get_all_groups_for_user_txn(txn): sql = """ SELECT group_id, type, membership, u.content @@ -487,22 +493,22 @@ class GroupServerWorkerStore(SQLBaseStore): "group_id": row[0], "type": row[1], "membership": row[2], - "content": json.loads(row[3]), + "content": db_to_json(row[3]), } for row in txn ] - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) - def get_groups_changes_for_user(self, user_id, from_token, to_token): + async def get_groups_changes_for_user(self, user_id, from_token, to_token): from_token = int(from_token) has_changed = self._group_updates_stream_cache.has_entity_changed( user_id, from_token ) if not has_changed: - return defer.succeed([]) + return [] def _get_groups_changes_for_user_txn(txn): sql = """ @@ -517,22 +523,44 @@ class GroupServerWorkerStore(SQLBaseStore): "group_id": group_id, "membership": membership, "type": gtype, - "content": json.loads(content_json), + "content": db_to_json(content_json), } for group_id, membership, gtype, content_json in txn ] - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) - def get_all_groups_changes(self, from_token, to_token, limit): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_any_entity_changed( - from_token - ) + async def get_all_groups_changes( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for groups replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + last_id = int(last_id) + has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) + if not has_changed: - return defer.succeed([]) + return [], current_id, False def _get_all_groups_changes_txn(txn): sql = """ @@ -541,34 +569,61 @@ class GroupServerWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? LIMIT ? """ - txn.execute(sql, (from_token, to_token, limit)) - return [ - (stream_id, group_id, user_id, gtype, json.loads(content_json)) + txn.execute(sql, (last_id, current_id, limit)) + updates = [ + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) for stream_id, group_id, user_id, gtype, content_json in txn ] - return self.db.runInteraction( + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( "get_all_groups_changes", _get_all_groups_changes_txn ) class GroupServerStore(GroupServerWorkerStore): - def set_group_join_policy(self, group_id, join_policy): + async def set_group_join_policy(self, group_id: str, join_policy: str) -> None: """Set the join policy of a group. join_policy can be one of: * "invite" * "open" """ - return self.db.simple_update_one( + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, desc="set_group_join_policy", ) - def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.db.runInteraction( + async def add_room_to_summary( + self, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) room's entry in summary. + + Args: + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -579,20 +634,28 @@ class GroupServerStore(GroupServerWorkerStore): ) def _add_room_to_summary_txn( - self, txn, group_id, room_id, category_id, order, is_public - ): + self, + txn, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: """Add (or update) room's entry in summary. Args: - group_id (str) - room_id (str) - category_id (str): If not None then adds the category to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the room at that position, e.g. - an order of 1 will put the room first. Otherwise, the room gets - added to the end. + txn + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public """ - room_in_group = self.db.simple_select_one_onecol_txn( + room_in_group = self.db_pool.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -605,7 +668,7 @@ class GroupServerStore(GroupServerWorkerStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self.db.simple_select_one_onecol_txn( + cat_exists = self.db_pool.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -616,7 +679,7 @@ class GroupServerStore(GroupServerWorkerStore): raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self.db.simple_select_one_onecol_txn( + cat_exists = self.db_pool.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -636,7 +699,7 @@ class GroupServerStore(GroupServerWorkerStore): (group_id, category_id, group_id, category_id), ) - existing = self.db.simple_select_one_txn( + existing = self.db_pool.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -669,7 +732,7 @@ class GroupServerStore(GroupServerWorkerStore): to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -683,7 +746,7 @@ class GroupServerStore(GroupServerWorkerStore): if is_public is None: is_public = True - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -695,11 +758,13 @@ class GroupServerStore(GroupServerWorkerStore): }, ) - def remove_room_from_summary(self, group_id, room_id, category_id): + async def remove_room_from_summary( + self, group_id: str, room_id: str, category_id: str + ) -> int: if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self.db.simple_delete( + return await self.db_pool.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -709,7 +774,13 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_room_from_summary", ) - def upsert_group_category(self, group_id, category_id, profile, is_public): + async def upsert_group_category( + self, + group_id: str, + category_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/update room category for group """ insertion_values = {} @@ -718,14 +789,14 @@ class GroupServerStore(GroupServerWorkerStore): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True else: update_values["is_public"] = is_public - return self.db.simple_upsert( + await self.db_pool.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -733,14 +804,20 @@ class GroupServerStore(GroupServerWorkerStore): desc="upsert_group_category", ) - def remove_group_category(self, group_id, category_id): - return self.db.simple_delete( + async def remove_group_category(self, group_id: str, category_id: str) -> int: + return await self.db_pool.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", ) - def upsert_group_role(self, group_id, role_id, profile, is_public): + async def upsert_group_role( + self, + group_id: str, + role_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/remove user role """ insertion_values = {} @@ -749,14 +826,14 @@ class GroupServerStore(GroupServerWorkerStore): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True else: update_values["is_public"] = is_public - return self.db.simple_upsert( + await self.db_pool.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -764,15 +841,34 @@ class GroupServerStore(GroupServerWorkerStore): desc="upsert_group_role", ) - def remove_group_role(self, group_id, role_id): - return self.db.simple_delete( + async def remove_group_role(self, group_id: str, role_id: str) -> int: + return await self.db_pool.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", ) - def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.db.runInteraction( + async def add_user_to_summary( + self, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) user's entry in summary. + + Args: + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -783,20 +879,28 @@ class GroupServerStore(GroupServerWorkerStore): ) def _add_user_to_summary_txn( - self, txn, group_id, user_id, role_id, order, is_public + self, + txn, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], ): """Add (or update) user's entry in summary. Args: - group_id (str) - user_id (str) - role_id (str): If not None then adds the role to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the user at that position, e.g. - an order of 1 will put the user first. Otherwise, the user gets - added to the end. + txn + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public """ - user_in_group = self.db.simple_select_one_onecol_txn( + user_in_group = self.db_pool.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -809,7 +913,7 @@ class GroupServerStore(GroupServerWorkerStore): if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self.db.simple_select_one_onecol_txn( + role_exists = self.db_pool.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -820,7 +924,7 @@ class GroupServerStore(GroupServerWorkerStore): raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self.db.simple_select_one_onecol_txn( + role_exists = self.db_pool.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -840,7 +944,7 @@ class GroupServerStore(GroupServerWorkerStore): (group_id, role_id, group_id, role_id), ) - existing = self.db.simple_select_one_txn( + existing = self.db_pool.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -869,7 +973,7 @@ class GroupServerStore(GroupServerWorkerStore): to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -883,7 +987,7 @@ class GroupServerStore(GroupServerWorkerStore): if is_public is None: is_public = True - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_summary_users", values={ @@ -895,50 +999,51 @@ class GroupServerStore(GroupServerWorkerStore): }, ) - def remove_user_from_summary(self, group_id, user_id, role_id): + async def remove_user_from_summary( + self, group_id: str, user_id: str, role_id: str + ) -> int: if role_id is None: role_id = _DEFAULT_ROLE_ID - return self.db.simple_delete( + return await self.db_pool.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", ) - def add_group_invite(self, group_id, user_id): + async def add_group_invite(self, group_id: str, user_id: str) -> None: """Record that the group server has invited a user """ - return self.db.simple_insert( + await self.db_pool.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", ) - def add_user_to_group( + async def add_user_to_group( self, - group_id, - user_id, - is_admin=False, - is_public=True, - local_attestation=None, - remote_attestation=None, - ): + group_id: str, + user_id: str, + is_admin: bool = False, + is_public: bool = True, + local_attestation: dict = None, + remote_attestation: dict = None, + ) -> None: """Add a user to the group server. Args: - group_id (str) - user_id (str) - is_admin (bool) - is_public (bool) - local_attestation (dict): The attestation the GS created to give - to the remote server. Optional if the user and group are on the - same server - remote_attestation (dict): The attestation given to GS by remote + group_id + user_id + is_admin + is_public + local_attestation: The attestation the GS created to give to the remote server. Optional if the user and group are on the same server + remote_attestation: The attestation given to GS by remote server. + Optional if the user and group are on the same server """ def _add_user_to_group_txn(txn): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_users", values={ @@ -949,14 +1054,14 @@ class GroupServerStore(GroupServerWorkerStore): }, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -966,137 +1071,145 @@ class GroupServerStore(GroupServerWorkerStore): }, ) if remote_attestation: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_attestations_remote", values={ "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) - return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) + await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) - def remove_user_from_group(self, group_id, user_id): + async def remove_user_from_group(self, group_id: str, user_id: str) -> None: def _remove_user_from_group_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) - def change_user_admin_in_group(self, group_id, user_id, is_admin): - return self.db.simple_update( + async def change_user_admin_in_group( + self, group_id: str, user_id: str, is_admin: bool + ) -> int: + return await self.db_pool.simple_update( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_admin": is_admin}, desc="change_user_admin_in_group" ) - def add_room_to_group(self, group_id, room_id, is_public): - return self.db.simple_insert( + async def add_room_to_group( + self, group_id: str, room_id: str, is_public: bool + ) -> None: + await self.db_pool.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) - def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self.db.simple_update( + async def update_room_in_group_visibility( + self, group_id: str, room_id: str, is_public: bool + ) -> int: + return await self.db_pool.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, desc="update_room_in_group_visibility", ) - def remove_room_from_group(self, group_id, room_id): + async def remove_room_from_group(self, group_id: str, room_id: str) -> None: def _remove_room_from_group_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) - def update_group_publicity(self, group_id, user_id, publicise): + async def update_group_publicity( + self, group_id: str, user_id: str, publicise: bool + ) -> None: """Update whether the user is publicising their membership of the group """ - return self.db.simple_update_one( + await self.db_pool.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, desc="update_group_publicity", ) - @defer.inlineCallbacks - def register_user_group_membership( + async def register_user_group_membership( self, - group_id, - user_id, - membership, - is_admin=False, - content={}, - local_attestation=None, - remote_attestation=None, - is_publicised=False, - ): + group_id: str, + user_id: str, + membership: str, + is_admin: bool = False, + content: JsonDict = {}, + local_attestation: Optional[dict] = None, + remote_attestation: Optional[dict] = None, + is_publicised: bool = False, + ) -> int: """Registers that a local user is a member of a (local or remote) group. Args: - group_id (str) - user_id (str) - membership (str) - is_admin (bool) - content (dict): Content of the membership, e.g. includes the inviter + group_id: The group the member is being added to. + user_id: THe user ID to add to the group. + membership: The type of group membership. + is_admin: Whether the user should be added as a group admin. + content: Content of the membership, e.g. includes the inviter if the user has been invited. - local_attestation (dict): If remote group then store the fact that we + local_attestation: If remote group then store the fact that we have given out an attestation, else None. - remote_attestation (dict): If remote group then store the remote + remote_attestation: If remote group then store the remote attestation from the group, else None. + is_publicised: Whether this should be publicised. """ def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="local_group_membership", values={ @@ -1105,11 +1218,11 @@ class GroupServerStore(GroupServerWorkerStore): "is_admin": is_admin, "membership": membership, "is_publicised": is_publicised, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="local_group_updates", values={ @@ -1117,7 +1230,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "type": "membership", - "content": json.dumps( + "content": json_encoder.encode( {"membership": membership, "content": content} ), }, @@ -1128,7 +1241,7 @@ class GroupServerStore(GroupServerWorkerStore): if membership == "join": if local_attestation: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -1138,23 +1251,23 @@ class GroupServerStore(GroupServerWorkerStore): }, ) if remote_attestation: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="group_attestations_remote", values={ "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) else: - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -1162,19 +1275,18 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with self._group_updates_id_gen.get_next() as next_id: - res = yield self.db.runInteraction( + with await self._group_updates_id_gen.get_next() as next_id: + res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, ) return res - @defer.inlineCallbacks - def create_group( + async def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description - ): - yield self.db.simple_insert( + ) -> None: + await self.db_pool.simple_insert( table="groups", values={ "group_id": group_id, @@ -1187,48 +1299,51 @@ class GroupServerStore(GroupServerWorkerStore): desc="create_group", ) - @defer.inlineCallbacks - def update_group_profile(self, group_id, profile): - yield self.db.simple_update_one( + async def update_group_profile(self, group_id, profile): + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, desc="update_group_profile", ) - def update_attestation_renewal(self, group_id, user_id, attestation): + async def update_attestation_renewal( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that we have renewed """ - return self.db.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, desc="update_attestation_renewal", ) - def update_remote_attestion(self, group_id, user_id, attestation): + async def update_remote_attestion( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that a remote has renewed """ - return self.db.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json.dumps(attestation), + "attestation_json": json_encoder.encode(attestation), }, desc="update_remote_attestion", ) - def remove_attestation_renewal(self, group_id, user_id): + async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int: """Remove an attestation that we thought we should renew, but actually shouldn't. Ideally this would never get called as we would never incorrectly try and do attestations for local users on local groups. Args: - group_id (str) - user_id (str) + group_id + user_id """ - return self.db.simple_delete( + return await self.db_pool.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", @@ -1237,14 +1352,11 @@ class GroupServerStore(GroupServerWorkerStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def delete_group(self, group_id): + async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. Args: - group_id (str) - - Returns: - Deferred + group_id: The group ID to delete. """ def _delete_group_txn(txn): @@ -1264,8 +1376,8 @@ class GroupServerStore(GroupServerWorkerStore): ] for table in tables: - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) - return self.db.runInteraction("delete_group", _delete_group_txn) + await self.db_pool.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/databases/main/keys.py index 4e1642a27a..ad43bb05ab 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,6 +16,7 @@ import itertools import logging +from typing import Dict, Iterable, List, Optional, Tuple from signedjson.key import decode_verify_key_bytes @@ -41,16 +42,17 @@ class KeyStore(SQLBaseStore): @cachedList( cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" ) - def get_server_verify_keys(self, server_name_and_key_ids): + async def get_server_verify_keys( + self, server_name_and_key_ids: Iterable[Tuple[str, str]] + ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: """ Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): + server_name_and_key_ids: iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown + A map from (server_name, key_id) -> FetchKeyResult, or None if the + key is unknown """ keys = {} @@ -86,14 +88,19 @@ class KeyStore(SQLBaseStore): _get_keys(txn, batch) return keys - return self.db.runInteraction("get_server_verify_keys", _txn) + return await self.db_pool.runInteraction("get_server_verify_keys", _txn) - def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + async def store_server_verify_keys( + self, + from_server: str, + ts_added_ms: int, + verify_keys: Iterable[Tuple[str, str, FetchKeyResult]], + ) -> None: """Stores NACL verification keys for remote servers. Args: - from_server (str): Where the verification keys were looked up - ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + from_server: Where the verification keys were looked up + ts_added_ms: The time to record that the key was added + verify_keys: keys to be stored. Each entry is a triplet of (server_name, key_id, key). """ @@ -115,15 +122,9 @@ class KeyStore(SQLBaseStore): # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) - def _invalidate(res): - f = self._get_server_verify_key.invalidate - for i in invalidations: - f((i,)) - return res - - return self.db.runInteraction( + await self.db_pool.runInteraction( "store_server_verify_keys", - self.db.simple_upsert_many_txn, + self.db_pool.simple_upsert_many_txn, table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -134,24 +135,34 @@ class KeyStore(SQLBaseStore): "verify_key", ), value_values=value_values, - ).addCallback(_invalidate) + ) - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): + invalidate = self._get_server_verify_key.invalidate + for i in invalidations: + invalidate((i,)) + + async def store_server_keys_json( + self, + server_name: str, + key_id: str, + from_server: str, + ts_now_ms: int, + ts_expires_ms: int, + key_json_bytes: bytes, + ) -> None: """Stores the JSON bytes for a set of keys from a server The JSON should be signed by the originating server, the intermediate server, and by this server. Updates the value for the (server_name, key_id, from_server) triplet if one already existed. Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. + server_name: The name of the server. + key_id: The identifer of the key this JSON is for. + from_server: The server this JSON was fetched from. + ts_now_ms: The time now in milliseconds. + ts_valid_until_ms: The time when this json stops being valid. + key_json_bytes: The encoded JSON. """ - return self.db.simple_upsert( + await self.db_pool.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -169,7 +180,9 @@ class KeyStore(SQLBaseStore): desc="store_server_keys_json", ) - def get_server_keys_json(self, server_keys): + async def get_server_keys_json( + self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: """Retrive the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. @@ -178,8 +191,7 @@ class KeyStore(SQLBaseStore): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts + A mapping from (server_name, key_id, source) triplets to a list of dicts """ def _get_server_keys_json_txn(txn): @@ -190,7 +202,7 @@ class KeyStore(SQLBaseStore): keyvalues["key_id"] = key_id if from_server is not None: keyvalues["from_server"] = from_server - rows = self.db.simple_select_list_txn( + rows = self.db_pool.simple_select_list_txn( txn, "server_keys_json", keyvalues=keyvalues, @@ -205,4 +217,6 @@ class KeyStore(SQLBaseStore): results[(server_name, key_id, from_server)] = rows return results - return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) + return await self.db_pool.runInteraction( + "get_server_keys_json", _get_server_keys_json_txn + ) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 8aecd414c2..86557d5512 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,17 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Iterable, List, Optional, Tuple + from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryBackgroundUpdateStore, self).__init__( database, db_conn, hs ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( update_name="local_media_repository_url_idx", index_name="local_media_repository_url_idx", table="local_media_repository", @@ -34,15 +36,16 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryStore, self).__init__(database, db_conn, hs) - def get_local_media(self, media_id): + async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media + Returns: None if the media_id doesn't exist. """ - return self.db.simple_select_one( + return await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -57,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media", ) - def store_local_media( + async def store_local_media( self, media_id, media_type, @@ -66,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length, user_id, url_cache=None, - ): - return self.db.simple_insert( + ) -> None: + await self.db_pool.simple_insert( "local_media_repository", { "media_id": media_id, @@ -81,7 +84,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_media", ) - def get_url_cache(self, url, ts): + async def mark_local_media_as_safe(self, media_id: str) -> None: + """Mark a local media as safe from quarantining.""" + await self.db_pool.simple_update_one( + table="local_media_repository", + keyvalues={"media_id": media_id}, + updatevalues={"safe_from_quarantine": True}, + desc="mark_local_media_as_safe", + ) + + async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]: """Get the media_id and ts for a cached URL as of the given timestamp Returns: None if the URL isn't cached. @@ -127,12 +139,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) ) - return self.db.runInteraction("get_url_cache", get_url_cache_txn) + return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) - def store_url_cache( + async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self.db.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -146,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_url_cache", ) - def get_local_media_thumbnails(self, media_id): - return self.db.simple_select_list( + async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -160,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media_thumbnails", ) - def store_local_thumbnail( + async def store_local_thumbnail( self, media_id, thumbnail_width, @@ -169,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.db.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -182,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_thumbnail", ) - def get_cached_remote_media(self, origin, media_id): - return self.db.simple_select_one( + async def get_cached_remote_media( + self, origin, media_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -198,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_cached_remote_media", ) - def store_cached_remote_media( + async def store_cached_remote_media( self, origin, media_id, @@ -208,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self.db.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -223,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_cached_remote_media", ) - def update_cached_last_access_time(self, local_media, remote_media, time_ms): + async def update_cached_last_access_time( + self, + local_media: Iterable[str], + remote_media: Iterable[Tuple[str, str]], + time_ms: int, + ): """Updates the last access time of the given media Args: - local_media (iterable[str]): Set of media_ids - remote_media (iterable[(str, str)]): Set of (server_name, media_id) + local_media: Set of media_ids + remote_media: Set of (server_name, media_id) time_ms: Current time in milliseconds """ @@ -253,12 +272,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn ) - def get_remote_media_thumbnails(self, origin, media_id): - return self.db.simple_select_list( + async def get_remote_media_thumbnails( + self, origin: str, media_id: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -272,7 +293,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnails", ) - def store_remote_media_thumbnail( + async def store_remote_media_thumbnail( self, origin, media_id, @@ -283,7 +304,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.db.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -298,33 +319,35 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_remote_media_thumbnail", ) - def get_remote_media_before(self, before_ts): + async def get_remote_media_before(self, before_ts): sql = ( "SELECT media_origin, media_id, filesystem_id" " FROM remote_media_cache" " WHERE last_access_ts < ?" ) - return self.db.execute( - "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts + return await self.db_pool.execute( + "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts ) - def delete_remote_media(self, media_origin, media_id): + async def delete_remote_media(self, media_origin: str, media_id: str) -> None: def delete_remote_media_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) + await self.db_pool.runInteraction( + "delete_remote_media", delete_remote_media_txn + ) - def get_expired_url_cache(self, now_ts): + async def get_expired_url_cache(self, now_ts: int) -> List[str]: sql = ( "SELECT media_id FROM local_media_repository_url_cache" " WHERE expires_ts < ?" @@ -336,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (now_ts,)) return [row[0] for row in txn] - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_expired_url_cache", _get_expired_url_cache_txn ) @@ -349,9 +372,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) + return await self.db_pool.runInteraction( + "delete_url_cache", _delete_url_cache_txn + ) - def get_url_cache_media_before(self, before_ts): + async def get_url_cache_media_before(self, before_ts: int) -> List[str]: sql = ( "SELECT media_id FROM local_media_repository" " WHERE created_ts < ? AND url_cache IS NOT NULL" @@ -363,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (before_ts,)) return [row[0] for row in txn] - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_url_cache_media_before", _get_url_cache_media_before_txn ) @@ -380,6 +405,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/databases/main/metrics.py index dad5bbc602..686052bd83 100644 --- a/synapse/storage/data_stores/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -15,15 +15,13 @@ import typing from collections import Counter -from twisted.internet import defer - from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.event_push_actions import ( +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) -from synapse.storage.database import Database class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): @@ -31,7 +29,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): stats and prometheus metrics. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) # Collect metrics on the number of forward extremities that exist. @@ -66,11 +64,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): ) return txn.fetchall() - res = await self.db.runInteraction("read_forward_extremities", fetch) + res = await self.db_pool.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = Counter([x[0] for x in res]) - @defer.inlineCallbacks - def count_daily_messages(self): + async def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. @@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db.runInteraction("count_messages", _count_messages) - return ret + return await self.db_pool.runInteraction("count_messages", _count_messages) - @defer.inlineCallbacks - def count_daily_sent_messages(self): + async def count_daily_sent_messages(self): def _count_messages(txn): # This is good enough as if you have silly characters in your own # hostname then thats your own fault. @@ -109,11 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) - return ret + return await self.db_pool.runInteraction( + "count_daily_sent_messages", _count_messages + ) - @defer.inlineCallbacks - def count_daily_active_rooms(self): + async def count_daily_active_rooms(self): def _count(txn): sql = """ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events @@ -124,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db.runInteraction("count_daily_active_rooms", _count) - return ret + return await self.db_pool.runInteraction("count_daily_active_rooms", _count) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e459cf49a0..1d793d3deb 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -13,12 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List - -from twisted.internet import defer +from typing import Dict, List from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database, make_in_list_sql_clause +from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -29,17 +27,17 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @cached(num_args=0) - def get_monthly_active_count(self): + async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users Returns: - Defered[int]: Number of current monthly active users + Number of current monthly active users """ def _count_users(txn): @@ -48,10 +46,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): (count,) = txn.fetchone() return count - return self.db.runInteraction("count_users", _count_users) + return await self.db_pool.runInteraction("count_users", _count_users) @cached(num_args=0) - def get_monthly_active_count_by_service(self): + async def get_monthly_active_count_by_service(self) -> Dict[str, int]: """Generates current count of monthly active users broken down by service. A service is typically an appservice but also includes native matrix users. Since the `monthly_active_users` table is populated from the `user_ips` table @@ -59,8 +57,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): method to return anything other than native matrix users. Returns: - Deferred[dict]: dict that includes a mapping between app_service_id - and the number of occurrences. + A mapping between app_service_id and the number of occurrences. """ @@ -76,7 +73,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): result = txn.fetchall() return dict(result) - return self.db.runInteraction("count_users_by_service", _count_users_by_service) + return await self.db_pool.runInteraction( + "count_users_by_service", _count_users_by_service + ) async def get_registered_reserved_users(self) -> List[str]: """Of the reserved threepids defined in config, retrieve those that are associated @@ -99,17 +98,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): return users @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): + async def user_last_seen_monthly_active(self, user_id: str) -> int: """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen + Checks if a given user is part of the monthly active user group + + Arguments: + user_id: user to add/update + Return: + Timestamp since last seen, None if never seen """ - return self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", @@ -119,7 +119,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) self._limit_usage_by_mau = hs.config.limit_usage_by_mau @@ -128,7 +128,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # Do not add more reserved users than the total allowable number # cur = LoggingTransaction( - self.db.new_transaction( + self.db_pool.new_transaction( db_conn, "initialise_mau_threepids", [], @@ -162,7 +162,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): is_support = self.is_support_user_txn(txn, user_id) if not is_support: # We do this manually here to avoid hitting #6791 - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -246,20 +246,16 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) reserved_users = await self.get_registered_reserved_users() - await self.db.runInteraction( + await self.db_pool.runInteraction( "reap_monthly_active_users", _reap_users, reserved_users ) - @defer.inlineCallbacks - def upsert_monthly_active_user(self, user_id): + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server Args: - user_id (str): user to add/update - - Returns: - Deferred + user_id: user to add/update """ # Support user never to be included in MAU stats. Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of @@ -269,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # _initialise_reserved_users reasoning that it would be very strange to # include a support user in this context. - is_support = yield self.is_support_user(user_id) + is_support = await self.is_support_user(user_id) if is_support: return - yield self.db.runInteraction( + await self.db_pool.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) @@ -303,7 +299,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self.db.simple_upsert_txn( + is_insert = self.db_pool.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -320,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): return is_insert - @defer.inlineCallbacks - def populate_monthly_active_users(self, user_id): + async def populate_monthly_active_users(self, user_id): """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables @@ -330,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): """ if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = yield self.is_guest(user_id) + is_guest = await self.is_guest(user_id) if is_guest: return - is_trial = yield self.is_trial_user(user_id) + is_trial = await self.is_trial_user(user_id) if is_trial: return - last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) + last_seen_timestamp = await self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() # We want to reduce to the total number of db writes, and are happy @@ -350,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # False, there is no point in checking get_monthly_active_count - it # adds no value and will break the logic if max_mau_value is exceeded. if not self._limit_usage_by_mau: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) else: - count = yield self.get_monthly_active_count() + count = await self.get_monthly_active_count() if count < self._max_mau_value: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/databases/main/openid.py index cc21437e92..2aac64901b 100644 --- a/synapse/storage/data_stores/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -1,9 +1,13 @@ +from typing import Optional + from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): - def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.db.simple_insert( + async def insert_open_id_token( + self, token: str, ts_valid_until_ms: int, user_id: str + ) -> None: + await self.db_pool.simple_insert( table="open_id_tokens", values={ "token": token, @@ -13,7 +17,9 @@ class OpenIdStore(SQLBaseStore): desc="insert_open_id_token", ) - def get_user_id_for_open_id_token(self, token, ts_now_ms): + async def get_user_id_for_open_id_token( + self, token: str, ts_now_ms: int + ) -> Optional[str]: def get_user_id_for_token_txn(txn): sql = ( "SELECT user_id FROM open_id_tokens" @@ -28,6 +34,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_user_id_for_token", get_user_id_for_token_txn ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/databases/main/presence.py index dab31e0c2d..c9f655dfb7 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -13,23 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from typing import List, Tuple +from synapse.api.presence import UserPresenceState from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( + async def update_presence(self, presence_states): + stream_ordering_manager = await self._presence_id_gen.get_next_mult( len(presence_states) ) with stream_ordering_manager as stream_orderings: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, @@ -46,7 +45,7 @@ class PresenceStore(SQLBaseStore): txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="presence_stream", values=[ @@ -73,9 +72,32 @@ class PresenceStore(SQLBaseStore): ) txn.execute(sql + clause, [stream_id] + list(args)) - def get_all_presence_updates(self, last_id, current_id, limit): + async def get_all_presence_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for presence replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_presence_updates_txn(txn): sql = """ @@ -89,9 +111,17 @@ class PresenceStore(SQLBaseStore): LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + updates = [(row[0], row[1:]) for row in txn] + + upper_bound = current_id + limited = False + if len(updates) >= limit: + upper_bound = updates[-1][0] + limited = True - return self.db.runInteraction( + return updates, upper_bound, limited + + return await self.db_pool.runInteraction( "get_all_presence_updates", get_all_presence_updates_txn ) @@ -100,13 +130,10 @@ class PresenceStore(SQLBaseStore): raise NotImplementedError() @cachedList( - cached_method_name="_get_presence_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, + cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1, ) - def get_presence_for_users(self, user_ids): - rows = yield self.db.simple_select_many_batch( + async def get_presence_for_users(self, user_ids): + rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -130,24 +157,3 @@ class PresenceStore(SQLBaseStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - - def allow_presence_visible(self, observed_localpart, observer_userid): - return self.db.simple_insert( - table="presence_allow_inbound", - values={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="allow_presence_visible", - or_ignore=True, - ) - - def disallow_presence_visible(self, observed_localpart, observer_userid): - return self.db.simple_delete_one( - table="presence_allow_inbound", - keyvalues={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="disallow_presence_visible", - ) diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/databases/main/profile.py index bfc9369f0b..d2e0685e9e 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -12,19 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer +from typing import Any, Dict, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.roommember import ProfileInfo +from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: try: - profile = yield self.db.simple_select_one( + profile = await self.db_pool.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -41,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - def get_profile_displayname(self, user_localpart): - return self.db.simple_select_one_onecol( + async def get_profile_displayname(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", desc="get_profile_displayname", ) - def get_profile_avatar_url(self, user_localpart): - return self.db.simple_select_one_onecol( + async def get_profile_avatar_url(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", desc="get_profile_avatar_url", ) - def get_from_remote_profile_cache(self, user_id): - return self.db.simple_select_one( + async def get_from_remote_profile_cache( + self, user_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -66,21 +66,25 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - def create_profile(self, user_localpart): - return self.db.simple_insert( + async def create_profile(self, user_localpart: str) -> None: + await self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self.db.simple_update_one( + async def set_profile_displayname( + self, user_localpart: str, new_displayname: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, desc="set_profile_displayname", ) - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.db.simple_update_one( + async def set_profile_avatar_url( + self, user_localpart: str, new_avatar_url: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -89,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore): class ProfileStore(ProfileWorkerStore): - def add_remote_profile_cache(self, user_id, displayname, avatar_url): + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: """Ensure we are caching the remote user's profiles. This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self.db.simple_upsert( + await self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -106,8 +112,10 @@ class ProfileStore(ProfileWorkerStore): desc="add_remote_profile_cache", ) - def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.db.simple_update( + async def update_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> int: + return await self.db_pool.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, updatevalues={ @@ -118,20 +126,21 @@ class ProfileStore(ProfileWorkerStore): desc="update_remote_profile_cache", ) - @defer.inlineCallbacks - def maybe_delete_remote_profile_cache(self, user_id): + async def maybe_delete_remote_profile_cache(self, user_id): """Check if we still care about the remote user's profile, and if we don't then remove their profile from the cache """ - subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + subscribed = await self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self.db.simple_delete( + await self.db_pool.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", ) - def get_remote_profile_cache_entries_that_expire(self, last_checked): + async def get_remote_profile_cache_entries_that_expire( + self, last_checked: int + ) -> Dict[str, str]: """Get all users who haven't been checked since `last_checked` """ @@ -144,18 +153,17 @@ class ProfileStore(ProfileWorkerStore): txn.execute(sql, (last_checked,)) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) - @defer.inlineCallbacks - def is_subscribed_remote_profile_for_user(self, user_id): + async def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self.db.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +174,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self.db.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index a93e1ef198..ea833829ae 100644 --- a/synapse/storage/data_stores/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -14,36 +14,35 @@ # limitations under the License. import logging -from typing import Any, Tuple +from typing import Any, List, Set, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken logger = logging.getLogger(__name__) class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> Set[int]: """Deletes room history before a certain point Args: - room_id (str): - - token (str): A topological token to delete events before - - delete_local_events (bool): + room_id: + token: A topological token to delete events before + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). Returns: - Deferred[set[int]]: The set of state groups that are referenced by - deleted events. + The set of state groups that are referenced by deleted events. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -62,6 +61,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # event_json # event_push_actions # event_reference_hashes + # event_relations # event_search # event_to_state_groups # events @@ -209,6 +209,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "event_edges", "event_forward_extremities", "event_reference_hashes", + "event_relations", "event_search", "rejections", ): @@ -281,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): return referenced_state_groups - def purge_room(self, room_id): + async def purge_room(self, room_id: str) -> List[int]: """Deletes all record of a room Args: - room_id (str) + room_id Returns: - Deferred[List[int]]: The list of state groups to delete. + The list of state groups to delete. """ - - return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) + return await self.db_pool.runInteraction( + "purge_room", self._purge_room_txn, room_id + ) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before @@ -361,7 +363,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "event_push_summary", "pusher_throttle", "group_summary_rooms", - "local_invites", "room_account_data", "room_tags", "local_current_membership", diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index ef8f40959f..0de802a86b 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -16,40 +16,37 @@ import abc import logging -from typing import Union - -from canonicaljson import json - -from twisted.internet import defer +from typing import List, Tuple, Union from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.pusher import PusherWorkerStore -from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.database import Database +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.pusher import PusherWorkerStore +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import ChainedIdGenerator -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util import json_encoder +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -def _load_rules(rawrules, enabled_map): +def _load_rules(rawrules, enabled_map, use_new_defaults=False): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) + rule["conditions"] = db_to_json(rawrule["conditions"]) + rule["actions"] = db_to_json(rawrule["actions"]) rule["default"] = False ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist)) + rules = list(list_with_base_rules(ruleslist, use_new_defaults)) for i, rule in enumerate(rules): rule_id = rule["rule_id"] @@ -79,19 +76,19 @@ class PushRulesWorkerStore( # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) # type: Union[ChainedIdGenerator, SlavedIdTracker] + self._push_rules_stream_id_gen = StreamIdGenerator( + db_conn, "push_rules_stream", "stream_id" + ) # type: Union[StreamIdGenerator, SlavedIdTracker] else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" ) - push_rules_prefill, push_rules_id = self.db.get_cache_dict( + push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", @@ -105,6 +102,8 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) + self._users_new_default_push_rules = hs.config.users_new_default_push_rules + @abc.abstractmethod def get_max_push_rules_stream_id(self): """Get the position of the push rules stream. @@ -114,9 +113,9 @@ class PushRulesWorkerStore( """ raise NotImplementedError() - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_for_user(self, user_id): - rows = yield self.db.simple_select_list( + @cached(max_entries=5000) + async def get_push_rules_for_user(self, user_id): + rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -132,15 +131,15 @@ class PushRulesWorkerStore( rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + enabled_map = await self.get_push_rules_enabled_for_user(user_id) - rules = _load_rules(rows, enabled_map) + use_new_defaults = user_id in self._users_new_default_push_rules - return rules + return _load_rules(rows, enabled_map, use_new_defaults) - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_enabled_for_user(self, user_id): - results = yield self.db.simple_select_list( + @cached(max_entries=5000) + async def get_push_rules_enabled_for_user(self, user_id): + results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -148,9 +147,11 @@ class PushRulesWorkerStore( ) return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - def have_push_rules_changed_for_user(self, user_id, last_id): + async def have_push_rules_changed_for_user( + self, user_id: str, last_id: int + ) -> bool: if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) + return False else: def have_push_rules_changed_txn(txn): @@ -162,23 +163,20 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, + cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, ) - def bulk_get_push_rules(self, user_ids): + async def bulk_get_push_rules(self, user_ids): if not user_ids: return {} results = {user_id: [] for user_id in user_ids} - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -191,21 +189,26 @@ class PushRulesWorkerStore( for row in rows: results.setdefault(row["user_name"], []).append(row) - enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + use_new_defaults = user_id in self._users_new_default_push_rules + + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}), use_new_defaults, + ) return results - @defer.inlineCallbacks - def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): + async def copy_push_rule_from_room_to_room( + self, new_room_id: str, user_id: str, rule: dict + ) -> None: """Copy a single push rule from one room to another for a specific user. Args: - new_room_id (str): ID of the new room. - user_id (str): ID of user the push rule belongs to. - rule (Dict): A push rule. + new_room_id: ID of the new room. + user_id : ID of user the push rule belongs to. + rule: A push rule. """ # Create new rule id rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) @@ -217,7 +220,7 @@ class PushRulesWorkerStore( condition["pattern"] = new_room_id # Add the rule for the new room - yield self.add_push_rule( + await self.add_push_rule( user_id=user_id, rule_id=new_rule_id, priority_class=rule["priority_class"], @@ -225,20 +228,19 @@ class PushRulesWorkerStore( actions=rule["actions"], ) - @defer.inlineCallbacks - def copy_push_rules_from_room_to_room_for_user( - self, old_room_id, new_room_id, user_id - ): + async def copy_push_rules_from_room_to_room_for_user( + self, old_room_id: str, new_room_id: str, user_id: str + ) -> None: """Copy all of the push rules from one room to another for a specific user. Args: - old_room_id (str): ID of the old room. - new_room_id (str): ID of the new room. - user_id (str): ID of user to copy push rules for. + old_room_id: ID of the old room. + new_room_id: ID of the new room. + user_id: ID of user to copy push rules for. """ # Retrieve push rules for this user - user_push_rules = yield self.get_push_rules_for_user(user_id) + user_push_rules = await self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room for rule in user_push_rules: @@ -247,96 +249,20 @@ class PushRulesWorkerStore( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions ): - yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) - - @defer.inlineCallbacks - def bulk_get_push_rules_for_room(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield context.get_current_state_ids() - result = yield self._bulk_get_push_rules_for_room( - event.room_id, state_group, current_state_ids, event=event - ) - return result - - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room( - self, room_id, state_group, current_state_ids, cache_context, event=None - ): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - users_in_room = yield self._get_joined_users_from_context( - room_id, - state_group, - current_state_ids, - on_invalidate=cache_context.invalidate, - event=event, - ) - - # We ignore app service users for now. This is so that we don't fill - # up the `get_if_users_have_pushers` cache with AS entries that we - # know don't have pushers, nor even read receipts. - local_users_in_room = { - u - for u in users_in_room - if self.hs.is_mine_id(u) - and not self.get_if_app_services_interested_in_user(u) - } - - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate - ) - user_ids = { - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - } - - users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - - rules_by_user = yield self.bulk_get_push_rules( - user_ids, on_invalidate=cache_context.invalidate - ) - - rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - - return rules_by_user + await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) @cachedList( cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", num_args=1, - inlineCallbacks=True, ) - def bulk_get_push_rules_enabled(self, user_ids): + async def bulk_get_push_rules_enabled(self, user_ids): if not user_ids: return {} results = {user_id: {} for user_id in user_ids} - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -348,30 +274,59 @@ class PushRulesWorkerStore( results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled return results - def get_all_push_rule_updates(self, last_id, current_id, limit): - """Get all the push rules changes that have happend on the server""" + async def get_all_push_rule_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for push_rules replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_push_rule_updates_txn(txn): - sql = ( - "SELECT stream_id, event_stream_ordering, user_id, rule_id," - " op, priority_class, priority, conditions, actions" - " FROM push_rules_stream" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) + sql = """ + SELECT stream_id, user_id + FROM push_rules_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + updates = [(stream_id, (user_id,)) for stream_id, user_id in txn] - return self.db.runInteraction( + limited = False + upper_bound = current_id + if len(updates) == limit: + limited = True + upper_bound = updates[-1][0] + + return updates, upper_bound, limited + + return await self.db_pool.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) class PushRuleStore(PushRulesWorkerStore): - @defer.inlineCallbacks - def add_push_rule( + async def add_push_rule( self, user_id, rule_id, @@ -380,13 +335,14 @@ class PushRuleStore(PushRulesWorkerStore): actions, before=None, after=None, - ): - conditions_json = json.dumps(conditions) - actions_json = json.dumps(actions) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids + ) -> None: + conditions_json = json_encoder.encode(conditions) + actions_json = json_encoder.encode(actions) + with await self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + if before or after: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, stream_id, @@ -400,7 +356,7 @@ class PushRuleStore(PushRulesWorkerStore): after, ) else: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, stream_id, @@ -431,7 +387,7 @@ class PushRuleStore(PushRulesWorkerStore): relative_to_rule = before or after - res = self.db.simple_select_one_txn( + res = self.db_pool.simple_select_one_txn( txn, table="push_rules", keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, @@ -554,7 +510,7 @@ class PushRuleStore(PushRulesWorkerStore): # We didn't update a row with the given rule_id so insert one push_rule_id = self._push_rule_id_gen.get_next() - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="push_rules", values={ @@ -584,20 +540,19 @@ class PushRuleStore(PushRulesWorkerStore): }, ) - @defer.inlineCallbacks - def delete_push_rule(self, user_id, rule_id): + async def delete_push_rule(self, user_id: str, rule_id: str) -> None: """ Delete a push rule. Args specify the row to be deleted and can be any of the columns in the push_rule table, but below are the standard ones Args: - user_id (str): The matrix ID of the push rule owner - rule_id (str): The rule_id of the rule to be deleted + user_id: The matrix ID of the push rule owner + rule_id: The rule_id of the rule to be deleted """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self.db.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -605,20 +560,21 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.db.runInteraction( + with await self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + + await self.db_pool.runInteraction( "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering, ) - @defer.inlineCallbacks - def set_push_rule_enabled(self, user_id, rule_id, enabled): - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.db.runInteraction( + async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: + with await self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + + await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, stream_id, @@ -632,7 +588,7 @@ class PushRuleStore(PushRulesWorkerStore): self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}, @@ -649,9 +605,10 @@ class PushRuleStore(PushRulesWorkerStore): op="ENABLE" if enabled else "DISABLE", ) - @defer.inlineCallbacks - def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): - actions_json = json.dumps(actions) + async def set_push_rule_actions( + self, user_id, rule_id, actions, is_default_rule + ) -> None: + actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): if is_default_rule: @@ -672,7 +629,7 @@ class PushRuleStore(PushRulesWorkerStore): update_stream=False, ) else: - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}, @@ -689,9 +646,10 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.db.runInteraction( + with await self._push_rules_stream_id_gen.get_next() as stream_id: + event_stream_ordering = self._stream_id_gen.get_current_token() + + await self.db_pool.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, stream_id, @@ -711,7 +669,7 @@ class PushRuleStore(PushRulesWorkerStore): if data is not None: values.update(data) - self.db.simple_insert_txn(txn, "push_rules_stream", values=values) + self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) @@ -719,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore): self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) - def get_push_rules_stream_token(self): - """Get the position of the push rules stream. - Returns a pair of a stream id for the push_rules stream and the - room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_current_token() - def get_max_push_rules_stream_id(self): - return self.get_push_rules_stream_token()[0] + return self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/databases/main/pusher.py index 547b9d69cb..c388468273 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -15,14 +15,12 @@ # limitations under the License. import logging -from typing import Iterable, Iterator +from typing import Iterable, Iterator, List, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json -from twisted.internet import defer - -from synapse.storage._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util.caches.descriptors import cached, cachedList logger = logging.getLogger(__name__) @@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore): Drops any rows whose data cannot be decoded """ for r in rows: - dataJson = r["data"] + data_json = r["data"] try: - r["data"] = json.loads(dataJson) + r["data"] = db_to_json(data_json) except Exception as e: logger.warning( "Invalid JSON in data for pusher %d: %s, %s", r["id"], - dataJson, + data_json, e.args[0], ) continue yield r - @defer.inlineCallbacks - def user_has_pusher(self, user_id): - ret = yield self.db.simple_select_one_onecol( + async def user_has_pusher(self, user_id): + ret = await self.db_pool.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore): def get_pushers_by_user_id(self, user_id): return self.get_pushers_by({"user_name": user_id}) - @defer.inlineCallbacks - def get_pushers_by(self, keyvalues): - ret = yield self.db.simple_select_list( + async def get_pushers_by(self, keyvalues): + ret = await self.db_pool.simple_select_list( "pushers", keyvalues, [ @@ -87,104 +83,91 @@ class PusherWorkerStore(SQLBaseStore): ) return self._decode_pushers_rows(ret) - @defer.inlineCallbacks - def get_all_pushers(self): + async def get_all_pushers(self): def get_pushers(txn): txn.execute("SELECT * FROM pushers") - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - rows = yield self.db.runInteraction("get_all_pushers", get_pushers) - return rows + return await self.db_pool.runInteraction("get_all_pushers", get_pushers) - def get_all_updated_pushers(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed(([], [])) - - def get_all_updated_pushers_txn(txn): - sql = ( - "SELECT id, user_name, access_token, profile_tag, kind," - " app_id, app_display_name, device_display_name, pushkey, ts," - " lang, data" - " FROM pushers" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - updated = txn.fetchall() + async def get_all_updated_pushers_rows( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for pushers replication stream. - sql = ( - "SELECT stream_id, user_id, app_id, pushkey" - " FROM deleted_pushers" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - deleted = txn.fetchall() - - return updated, deleted + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. - return self.db.runInteraction( - "get_all_updated_pushers", get_all_updated_pushers_txn - ) + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. - def get_all_updated_pushers_rows(self, last_id, current_id, limit): - """Get all the pushers that have changed between the given tokens. + The token returned can be used in a subsequent call to this + function to get further updatees. - Returns: - Deferred(list(tuple)): each tuple consists of: - stream_id (str) - user_id (str) - app_id (str) - pushkey (str) - was_deleted (bool): whether the pusher was added/updated (False) - or deleted (True) + The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_updated_pushers_rows_txn(txn): - sql = ( - "SELECT id, user_name, app_id, pushkey" - " FROM pushers" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC LIMIT ?" - ) + sql = """ + SELECT id, user_name, app_id, pushkey + FROM pushers + WHERE ? < id AND id <= ? + ORDER BY id ASC LIMIT ? + """ txn.execute(sql, (last_id, current_id, limit)) - results = [list(row) + [False] for row in txn] - - sql = ( - "SELECT stream_id, user_id, app_id, pushkey" - " FROM deleted_pushers" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) + updates = [ + (stream_id, (user_name, app_id, pushkey, False)) + for stream_id, user_name, app_id, pushkey in txn + ] + + sql = """ + SELECT stream_id, user_id, app_id, pushkey + FROM deleted_pushers + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ txn.execute(sql, (last_id, current_id, limit)) + updates.extend( + (stream_id, (user_name, app_id, pushkey, True)) + for stream_id, user_name, app_id, pushkey in txn + ) + + updates.sort() # Sort so that they're ordered by stream id - results.extend(list(row) + [True] for row in txn) - results.sort() # Sort so that they're ordered by stream id + limited = False + upper_bound = current_id + if len(updates) >= limit: + limited = True + upper_bound = updates[-1][0] - return results + return updates, upper_bound, limited - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) - @cachedInlineCallbacks(num_args=1, max_entries=15000) - def get_if_user_has_pusher(self, user_id): + @cached(num_args=1, max_entries=15000) + async def get_if_user_has_pusher(self, user_id): # This only exists for the cachedList decorator raise NotImplementedError() @cachedList( - cached_method_name="get_if_user_has_pusher", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, + cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1, ) - def get_if_users_have_pushers(self, user_ids): - rows = yield self.db.simple_select_many_batch( + async def get_if_users_have_pushers(self, user_ids): + rows = await self.db_pool.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -197,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore): return result - @defer.inlineCallbacks - def update_pusher_last_stream_ordering( + async def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering - ): - yield self.db.simple_update_one( + ) -> None: + await self.db_pool.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, desc="update_pusher_last_stream_ordering", ) - @defer.inlineCallbacks - def update_pusher_last_stream_ordering_and_success( - self, app_id, pushkey, user_id, last_stream_ordering, last_success - ): + async def update_pusher_last_stream_ordering_and_success( + self, + app_id: str, + pushkey: str, + user_id: str, + last_stream_ordering: int, + last_success: int, + ) -> bool: """Update the last stream ordering position we've processed up to for the given pusher. Args: - app_id (str) - pushkey (str) - last_stream_ordering (int) - last_success (int) + app_id + pushkey + user_id + last_stream_ordering + last_success Returns: - Deferred[bool]: True if the pusher still exists; False if it has been deleted. + True if the pusher still exists; False if it has been deleted. """ - updated = yield self.db.simple_update( + updated = await self.db_pool.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -236,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore): return bool(updated) - @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self.db.simple_update( + async def update_pusher_failing_since( + self, app_id, pushkey, user_id, failing_since + ) -> None: + await self.db_pool.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, desc="update_pusher_failing_since", ) - @defer.inlineCallbacks - def get_throttle_params_by_room(self, pusher_id): - res = yield self.db.simple_select_list( + async def get_throttle_params_by_room(self, pusher_id): + res = await self.db_pool.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -263,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore): return params_by_room - @defer.inlineCallbacks - def set_throttle_params(self, pusher_id, room_id, params): + async def set_throttle_params(self, pusher_id, room_id, params) -> None: # no need to lock because `pusher_throttle` has a primary key on # (pusher, room_id) so simple_upsert will retry - yield self.db.simple_upsert( + await self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, @@ -280,8 +266,7 @@ class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() - @defer.inlineCallbacks - def add_pusher( + async def add_pusher( self, user_id, access_token, @@ -295,11 +280,11 @@ class PusherStore(PusherWorkerStore): data, last_stream_ordering, profile_tag="", - ): - with self._pushers_id_gen.get_next() as stream_id: + ) -> None: + with await self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry - yield self.db.simple_upsert( + await self.db_pool.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -324,21 +309,22 @@ class PusherStore(PusherWorkerStore): if user_has_pusher is not True: # invalidate, since we the user might not have had a pusher before - yield self.db.runInteraction( + await self.db_pool.runInteraction( "add_pusher", self._invalidate_cache_and_stream, self.get_if_user_has_pusher, (user_id,), ) - @defer.inlineCallbacks - def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): + async def delete_pusher_by_app_id_pushkey_user_id( + self, app_id, pushkey, user_id + ) -> None: def delete_pusher_txn(txn, stream_id): self._invalidate_cache_and_stream( txn, self.get_if_user_has_pusher, (user_id,) ) - self.db.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -347,7 +333,7 @@ class PusherStore(PusherWorkerStore): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -358,5 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with self._pushers_id_gen.get_next() as stream_id: - yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) + with await self._pushers_id_gen.get_next() as stream_id: + await self.db_pool.runInteraction( + "delete_pusher", delete_pusher_txn, stream_id + ) diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/databases/main/receipts.py index cebdcd409f..4a0d5a320e 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -16,15 +16,16 @@ import abc import logging - -from canonicaljson import json +from typing import Any, Dict, List, Optional, Tuple from twisted.internet import defer -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util import json_encoder +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -39,7 +40,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( @@ -55,14 +56,16 @@ class ReceiptsWorkerStore(SQLBaseStore): """ raise NotImplementedError() - @cachedInlineCallbacks() - def get_users_with_read_receipts_in_room(self, room_id): - receipts = yield self.get_receipts_for_room(room_id, "m.read") + @cached() + async def get_users_with_read_receipts_in_room(self, room_id): + receipts = await self.get_receipts_for_room(room_id, "m.read") return {r["user_id"] for r in receipts} @cached(num_args=2) - def get_receipts_for_room(self, room_id, receipt_type): - return self.db.simple_select_list( + async def get_receipts_for_room( + self, room_id: str, receipt_type: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -70,8 +73,10 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=3) - def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.db.simple_select_one_onecol( + async def get_last_receipt_event_id_for_user( + self, user_id: str, room_id: str, receipt_type: str + ) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -83,9 +88,9 @@ class ReceiptsWorkerStore(SQLBaseStore): allow_none=True, ) - @cachedInlineCallbacks(num_args=2) - def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self.db.simple_select_list( + @cached(num_args=2) + async def get_receipts_for_user(self, user_id, receipt_type): + rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -94,8 +99,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return {row["room_id"]: row["event_id"] for row in rows} - @defer.inlineCallbacks - def get_receipts_for_user_with_orderings(self, user_id, receipt_type): + async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): def f(txn): sql = ( "SELECT rl.room_id, rl.event_id," @@ -109,7 +113,9 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return txn.fetchall() - rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) + rows = await self.db_pool.runInteraction( + "get_receipts_for_user_with_orderings", f + ) return { row[0]: { "event_id": row[1], @@ -119,56 +125,61 @@ class ReceiptsWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def get_linearized_receipts_for_rooms( + self, room_ids: List[str], to_key: int, from_key: Optional[int] = None + ) -> List[dict]: """Get receipts for multiple rooms for sending to clients. Args: - room_ids (list): List of room_ids. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches + room_id: List of room_ids. + to_key: Max stream id to fetch receipts upto. + from_key: Min stream id to fetch receipts from. None fetches from the start. Returns: - list: A list of receipts. + A list of receipts. """ room_ids = set(room_ids) if from_key is not None: # Only ask the database about rooms where there have been new # receipts added since `from_key` - room_ids = yield self._receipts_stream_cache.get_entities_changed( + room_ids = self._receipts_stream_cache.get_entities_changed( room_ids, from_key ) - results = yield self._get_linearized_receipts_for_rooms( + results = await self._get_linearized_receipts_for_rooms( room_ids, to_key, from_key=from_key ) return [ev for res in results.values() for ev in res] - def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + async def get_linearized_receipts_for_room( + self, room_id: str, to_key: int, from_key: Optional[int] = None + ) -> List[dict]: """Get receipts for a single room for sending to clients. Args: - room_ids (str): The room id. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches + room_ids: The room id. + to_key: Max stream id to fetch receipts upto. + from_key: Min stream id to fetch receipts from. None fetches from the start. Returns: - Deferred[list]: A list of receipts. + A list of receipts. """ if from_key is not None: # Check the cache first to see if any new receipts have been added # since`from_key`. If not we can no-op. if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): - defer.succeed([]) + return [] - return self._get_linearized_receipts_for_room(room_id, to_key, from_key) + return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) - @cachedInlineCallbacks(num_args=3, tree=True) - def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + @cached(num_args=3, tree=True) + async def _get_linearized_receipts_for_room( + self, room_id: str, to_key: int, from_key: Optional[int] = None + ) -> List[dict]: """See get_linearized_receipts_for_room """ @@ -188,11 +199,11 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (room_id, to_key)) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) return rows - rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) + rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] @@ -201,7 +212,7 @@ class ReceiptsWorkerStore(SQLBaseStore): for row in rows: content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ row["user_id"] - ] = json.loads(row["data"]) + ] = db_to_json(row["data"]) return [{"type": "m.receipt", "room_id": room_id, "content": content}] @@ -209,9 +220,8 @@ class ReceiptsWorkerStore(SQLBaseStore): cached_method_name="_get_linearized_receipts_for_room", list_name="room_ids", num_args=3, - inlineCallbacks=True, ) - def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: return {} @@ -238,9 +248,9 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql + clause, [to_key] + list(args)) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - txn_results = yield self.db.runInteraction( + txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f ) @@ -258,7 +268,7 @@ class ReceiptsWorkerStore(SQLBaseStore): event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) - receipt_type[row["user_id"]] = json.loads(row["data"]) + receipt_type[row["user_id"]] = db_to_json(row["data"]) results = { room_id: [results[room_id]] if room_id in results else [] @@ -266,31 +276,86 @@ class ReceiptsWorkerStore(SQLBaseStore): } return results - def get_all_updated_receipts(self, last_id, current_id, limit=None): + async def get_users_sent_receipts_between( + self, last_id: int, current_id: int + ) -> List[str]: + """Get all users who sent receipts between `last_id` exclusive and + `current_id` inclusive. + + Returns: + The list of users. + """ + if last_id == current_id: return defer.succeed([]) + def _get_users_sent_receipts_between_txn(txn): + sql = """ + SELECT DISTINCT user_id FROM receipts_linearized + WHERE ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (last_id, current_id)) + + return [r[0] for r in txn] + + return await self.db_pool.runInteraction( + "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn + ) + + async def get_all_updated_receipts( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for receipts replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) + sql = """ + SELECT stream_id, room_id, receipt_type, user_id, event_id, data + FROM receipts_linearized + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + + updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] + + limited = False + upper_bound = current_id + + if len(updates) == limit: + limited = True + upper_bound = updates[-1][0] - return [r[0:5] + (json.loads(r[5]),) for r in txn] + return updates, upper_bound, limited - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) def _invalidate_get_users_with_receipts_in_room( - self, room_id, receipt_type, user_id + self, room_id: str, receipt_type: str, user_id: str ): if receipt_type != "m.read": return @@ -300,10 +365,10 @@ class ReceiptsWorkerStore(SQLBaseStore): room_id, None, update_metrics=False ) - # first handle the Deferred case - if isinstance(res, defer.Deferred): - if res.called: - res = res.result + # first handle the ObservableDeferred case + if isinstance(res, ObservableDeferred): + if res.has_called(): + res = res.get_result() else: res = None @@ -316,7 +381,7 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( @@ -338,7 +403,7 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self.db.simple_select_one_txn( + res = self.db_pool.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -391,7 +456,7 @@ class ReceiptsStore(ReceiptsWorkerStore): (user_id, room_id, receipt_type), ) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", keyvalues={ @@ -402,7 +467,7 @@ class ReceiptsStore(ReceiptsWorkerStore): values={ "stream_id": stream_id, "event_id": event_id, - "data": json.dumps(data), + "data": json_encoder.encode(data), }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock @@ -416,15 +481,21 @@ class ReceiptsStore(ReceiptsWorkerStore): return rx_ts - @defer.inlineCallbacks - def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): + async def insert_receipt( + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: dict, + ) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph representations. """ if not event_ids: - return + return None if len(event_ids) == 1: linearized_event_id = event_ids[0] @@ -451,13 +522,12 @@ class ReceiptsStore(ReceiptsWorkerStore): else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = yield self.db.runInteraction( + linearized_event_id = await self.db_pool.runInteraction( "insert_receipt_conv", graph_to_linear ) - stream_id_manager = self._receipts_id_gen.get_next() - with stream_id_manager as stream_id: - event_ts = yield self.db.runInteraction( + with await self._receipts_id_gen.get_next() as stream_id: + event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, @@ -479,14 +549,16 @@ class ReceiptsStore(ReceiptsWorkerStore): now - event_ts, ) - yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) + await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) max_persisted_id = self._receipts_id_gen.get_current_token() return stream_id, max_persisted_id - def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.db.runInteraction( + async def insert_graph_receipt( + self, room_id, receipt_type, user_id, event_ids, data + ): + return await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -512,7 +584,7 @@ class ReceiptsStore(ReceiptsWorkerStore): self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -521,14 +593,14 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, }, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="receipts_graph", values={ "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, - "event_ids": json.dumps(event_ids), - "data": json.dumps(data), + "event_ids": json_encoder.encode(event_ids), + "data": json_encoder.encode(data), }, ) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/databases/main/registration.py index 9768981891..01f20c03c2 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,20 +17,17 @@ import logging import re -from typing import Optional - -from six import iterkeys - -from twisted.internet import defer -from twisted.internet.defer import Deferred +from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 @@ -38,15 +35,19 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() + self._user_id_seq = build_sequence_generator( + database.engine, find_max_generated_user_id_localpart, "user_id_seq", + ) + @cached() - def get_user_by_id(self, user_id): - return self.db.simple_select_one( + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -65,19 +66,15 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_user_by_id", ) - @defer.inlineCallbacks - def is_trial_user(self, user_id): + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config Args: - user_id (str) - - Returns: - Deferred[bool] + user_id: The user to check for trial status. """ - info = yield self.get_user_by_id(user_id) + info = await self.get_user_by_id(user_id) if not info: return False @@ -87,60 +84,61 @@ class RegistrationWorkerStore(SQLBaseStore): return is_trial @cached() - def get_user_by_access_token(self, token): + async def get_user_by_access_token(self, token: str) -> Optional[dict]: """Get a user from the given access token. Args: - token (str): The access token of a user. + token: The access token of a user. Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. + None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`, + `valid_until_ms`. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) - @cachedInlineCallbacks() - def get_expiration_ts_for_user(self, user_id): + @cached() + async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]: """Get the expiration timestamp for the account bearing a given user ID. Args: - user_id (str): The ID of the user. + user_id: The ID of the user. Returns: - defer.Deferred: None, if the account has no expiration timestamp, - otherwise int representation of the timestamp (as a number of - milliseconds since epoch). + None, if the account has no expiration timestamp, otherwise int + representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", allow_none=True, desc="get_expiration_ts_for_user", ) - return res - @defer.inlineCallbacks - def set_account_validity_for_user( - self, user_id, expiration_ts, email_sent, renewal_token=None - ): + async def set_account_validity_for_user( + self, + user_id: str, + expiration_ts: int, + email_sent: bool, + renewal_token: Optional[str] = None, + ) -> None: """Updates the account validity properties of the given account, with the given values. Args: - user_id (str): ID of the account to update properties for. - expiration_ts (int): New expiration date, as a timestamp in milliseconds + user_id: ID of the account to update properties for. + expiration_ts: New expiration date, as a timestamp in milliseconds since epoch. - email_sent (bool): True means a renewal email has been sent for this - account and there's no need to send another one for the current validity + email_sent: True means a renewal email has been sent for this account + and there's no need to send another one for the current validity period. - renewal_token (str): Renewal token the user can use to extend the validity + renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. """ def set_account_validity_for_user_txn(txn): - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -154,75 +152,69 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.db.runInteraction( + await self.db_pool.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) - @defer.inlineCallbacks - def set_renewal_token_for_user(self, user_id, renewal_token): + async def set_renewal_token_for_user( + self, user_id: str, renewal_token: str + ) -> None: """Defines a renewal token for a given user. Args: - user_id (str): ID of the user to set the renewal token for. - renewal_token (str): Random unique string that will be used to renew the + user_id: ID of the user to set the renewal token for. + renewal_token: Random unique string that will be used to renew the user's account. Raises: StoreError: The provided token is already set for another user. """ - yield self.db.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, desc="set_renewal_token_for_user", ) - @defer.inlineCallbacks - def get_user_from_renewal_token(self, renewal_token): + async def get_user_from_renewal_token(self, renewal_token: str) -> str: """Get a user ID from a renewal token. Args: - renewal_token (str): The renewal token to perform the lookup with. + renewal_token: The renewal token to perform the lookup with. Returns: - defer.Deferred[str]: The ID of the user to which the token belongs. + The ID of the user to which the token belongs. """ - res = yield self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", desc="get_user_from_renewal_token", ) - return res - - @defer.inlineCallbacks - def get_renewal_token_for_user(self, user_id): + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. Args: - user_id (str): The user ID to lookup a token for. + user_id: The user ID to lookup a token for. Returns: - defer.Deferred[str]: The renewal token associated with this user ID. + The renewal token associated with this user ID. """ - res = yield self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", desc="get_renewal_token_for_user", ) - return res - - @defer.inlineCallbacks - def get_users_expiring_soon(self): + async def get_users_expiring_soon(self) -> List[Dict[str, int]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + A list of dictionaries mapping user ID to expiration time (in milliseconds). """ def select_users_txn(txn, now_ms, renew_at): @@ -232,58 +224,54 @@ class RegistrationWorkerStore(SQLBaseStore): ) values = [False, now_ms, renew_at] txn.execute(sql, values) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - res = yield self.db.runInteraction( + return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), self.config.account_validity.renew_at, ) - return res - - @defer.inlineCallbacks - def set_renewal_mail_status(self, user_id, email_sent): + async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: """Sets or unsets the flag that indicates whether a renewal email has been sent to the user (and the user hasn't renewed their account yet). Args: - user_id (str): ID of the user to set/unset the flag for. - email_sent (bool): Flag which indicates whether a renewal email has been sent + user_id: ID of the user to set/unset the flag for. + email_sent: Flag which indicates whether a renewal email has been sent to this user. """ - yield self.db.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, desc="set_renewal_mail_status", ) - @defer.inlineCallbacks - def delete_account_validity_for_user(self, user_id): + async def delete_account_validity_for_user(self, user_id: str) -> None: """Deletes the entry for the given user in the account validity table, removing their expiration date and renewal token. Args: - user_id (str): ID of the user to remove from the account validity table. + user_id: ID of the user to remove from the account validity table. """ - yield self.db.simple_delete_one( + await self.db_pool.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", ) - async def is_server_admin(self, user): + async def is_server_admin(self, user: UserID) -> bool: """Determines if a user is an admin of this homeserver. Args: - user (UserID): user ID of the user to test + user: user ID of the user to test - Returns (bool): + Returns: true iff the user is a server admin, false otherwise. """ - res = await self.db.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -293,28 +281,27 @@ class RegistrationWorkerStore(SQLBaseStore): return bool(res) if res else False - def set_server_admin(self, user, admin): + async def set_server_admin(self, user: UserID, admin: bool) -> None: """Sets whether a user is an admin of this homeserver. Args: - user (UserID): user ID of the user to test - admin (bool): true iff the user is to be a server admin, - false otherwise. + user: user ID of the user to test + admin: true iff the user is to be a server admin, false otherwise. """ def set_server_admin_txn(txn): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} ) self._invalidate_cache_and_stream( txn, self.get_user_by_id, (user.to_string(),) ) - return self.db.runInteraction("set_server_admin", set_server_admin_txn) + await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," + "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," " access_tokens.device_id, access_tokens.valid_until_ms" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" @@ -322,43 +309,42 @@ class RegistrationWorkerStore(SQLBaseStore): ) txn.execute(sql, (token,)) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if rows: return rows[0] return None - @cachedInlineCallbacks() - def is_real_user(self, user_id): + @cached() + async def is_real_user(self, user_id: str) -> bool: """Determines if the user is a real user, ie does not have a 'user_type'. Args: - user_id (str): user id to test + user_id: user id to test Returns: - Deferred[bool]: True if user 'user_type' is null or empty string + True if user 'user_type' is null or empty string """ - res = yield self.db.runInteraction( + return await self.db_pool.runInteraction( "is_real_user", self.is_real_user_txn, user_id ) - return res @cached() - def is_support_user(self, user_id): + async def is_support_user(self, user_id: str) -> bool: """Determines if the user is of type UserTypes.SUPPORT Args: - user_id (str): user id to test + user_id: user id to test Returns: - Deferred[bool]: True if user is of type UserTypes.SUPPORT + True if user is of type UserTypes.SUPPORT """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) def is_real_user_txn(self, txn, user_id): - res = self.db.simple_select_one_onecol_txn( + res = self.db_pool.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -368,7 +354,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res is None def is_support_user_txn(self, txn, user_id): - res = self.db.simple_select_one_onecol_txn( + res = self.db_pool.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -377,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore): ) return True if res == UserTypes.SUPPORT else False - def get_users_by_id_case_insensitive(self, user_id): + async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]: """Gets users that match user_id case insensitively. - Returns a mapping of user_id -> password_hash. + + Returns: + A mapping of user_id -> password_hash. """ def f(txn): @@ -387,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return dict(txn) - return self.db.runInteraction("get_users_by_id_case_insensitive", f) + return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) async def get_user_by_external_id( self, auth_provider: str, external_id: str @@ -401,7 +389,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: the mxid of the user, or None if they are not known """ - return await self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="user_external_ids", keyvalues={"auth_provider": auth_provider, "external_id": external_id}, retcol="user_id", @@ -409,21 +397,19 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_user_by_external_id", ) - @defer.inlineCallbacks - def count_all_users(self): + async def count_all_users(self): """Counts all users registered on the homeserver.""" def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.db.runInteraction("count_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_users", _count_users) - def count_daily_user_type(self): + async def count_daily_user_type(self) -> Dict[str, int]: """ Counts 1) native non guest users 2) native guests users @@ -452,10 +438,11 @@ class RegistrationWorkerStore(SQLBaseStore): results[row[0]] = row[1] return results - return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) + return await self.db_pool.runInteraction( + "count_daily_user_type", _count_daily_user_type + ) - @defer.inlineCallbacks - def count_nonbridged_users(self): + async def count_nonbridged_users(self): def _count_users(txn): txn.execute( """ @@ -466,56 +453,31 @@ class RegistrationWorkerStore(SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db.runInteraction("count_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_users", _count_users) - @defer.inlineCallbacks - def count_real_users(self): + async def count_real_users(self): """Counts all users without a special user_type registered on the homeserver.""" def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.db.runInteraction("count_real_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_real_users", _count_users) - @defer.inlineCallbacks - def find_next_generated_user_id_localpart(self): - """ - Gets the localpart of the next generated user ID. + async def generate_user_id(self) -> str: + """Generate a suitable localpart for a guest user - Generated user IDs are integers, so we find the largest integer user ID - already taken and return that plus one. + Returns: a (hopefully) free localpart """ - - def _find_next_generated_user_id(txn): - # We bound between '@0' and '@a' to avoid pulling the entire table - # out. - txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") - - regex = re.compile(r"^@(\d+):") - - max_found = 0 - - for (user_id,) in txn: - match = regex.search(user_id) - if match: - max_found = max(int(match.group(1)), max_found) - - return max_found + 1 - - return ( - ( - yield self.db.runInteraction( - "find_next_generated_user_id", _find_next_generated_user_id - ) - ) + next_id = await self.db_pool.runInteraction( + "generate_user_id", self._user_id_seq.get_next_id_txn ) + return str(next_id) + async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: """Returns user id from threepid @@ -526,7 +488,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: The user ID or None if no user id/threepid mapping exists """ - user_id = await self.db.runInteraction( + user_id = await self.db_pool.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) return user_id @@ -542,7 +504,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self.db.simple_select_one_txn( + ret = self.db_pool.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -553,61 +515,57 @@ class RegistrationWorkerStore(SQLBaseStore): return ret["user_id"] return None - @defer.inlineCallbacks - def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self.db.simple_upsert( + async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - @defer.inlineCallbacks - def user_get_threepids(self, user_id): - ret = yield self.db.simple_select_list( + async def user_get_threepids(self, user_id): + return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], "user_get_threepids", ) - return ret - def user_delete_threepid(self, user_id, medium, address): - return self.db.simple_delete( + async def user_delete_threepid(self, user_id, medium, address) -> None: + await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", ) - def user_delete_threepids(self, user_id: str): + async def user_delete_threepids(self, user_id: str) -> None: """Delete all threepid this user has bound Args: user_id: The user id to delete all threepids of """ - return self.db.simple_delete( + await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", ) - def add_user_bound_threepid(self, user_id, medium, address, id_server): + async def add_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ): """The server proxied a bind request to the given identity server on behalf of the given user. We need to remember this in case the user asks us to unbind the threepid. Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Deferred + user_id + medium + address + id_server """ # We need to use an upsert, in case they user had already bound the # threepid - return self.db.simple_upsert( + await self.db_pool.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -620,41 +578,40 @@ class RegistrationWorkerStore(SQLBaseStore): desc="add_user_bound_threepid", ) - def user_get_bound_threepids(self, user_id): + async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: """Get the threepids that a user has bound to an identity server through the homeserver The homeserver remembers where binds to an identity server occurred. Using this method can retrieve those threepids. Args: - user_id (str): The ID of the user to retrieve threepids for + user_id: The ID of the user to retrieve threepids for Returns: - Deferred[list[dict]]: List of dictionaries containing the following: + List of dictionaries containing the following keys: medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self.db.simple_select_list( + return await self.db_pool.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], desc="user_get_bound_threepids", ) - def remove_user_bound_threepid(self, user_id, medium, address, id_server): + async def remove_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ) -> None: """The server proxied an unbind request to the given identity server on behalf of the given user, so we remove the mapping of threepid to identity server. Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Deferred + user_id + medium + address + id_server """ - return self.db.simple_delete( + await self.db_pool.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -665,37 +622,39 @@ class RegistrationWorkerStore(SQLBaseStore): desc="remove_user_bound_threepid", ) - def get_id_servers_user_bound(self, user_id, medium, address): + async def get_id_servers_user_bound( + self, user_id: str, medium: str, address: str + ) -> List[str]: """Get the list of identity servers that the server proxied bind requests to for given user and threepid Args: - user_id (str) - medium (str) - address (str) + user_id: The user to query for identity servers. + medium: The medium to query for identity servers. + address: The address to query for identity servers. Returns: - Deferred[list[str]]: Resolves to a list of identity servers + A list of identity servers """ - return self.db.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", desc="get_id_servers_user_bound", ) - @cachedInlineCallbacks() - def get_user_deactivated_status(self, user_id): + @cached() + async def get_user_deactivated_status(self, user_id: str) -> bool: """Retrieve the value for the `deactivated` property for the provided user. Args: - user_id (str): The ID of the user to retrieve the status for. + user_id: The ID of the user to retrieve the status for. Returns: - defer.Deferred(bool): The requested value. + True if the user was deactivated, false if the user is still active. """ - res = yield self.db.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -705,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore): # Convert the integer into a boolean. return res == 1 - def get_threepid_validation_session( - self, medium, client_secret, address=None, sid=None, validated=True - ): + async def get_threepid_validation_session( + self, + medium: Optional[str], + client_secret: str, + address: Optional[str] = None, + sid: Optional[str] = None, + validated: Optional[bool] = True, + ) -> Optional[Dict[str, Any]]: """Gets a session_id and last_send_attempt (if available) for a combination of validation metadata Args: - medium (str|None): The medium of the 3PID - address (str|None): The address of the 3PID - sid (str|None): The ID of the validation session - client_secret (str): A unique string provided by the client to help identify this + medium: The medium of the 3PID + client_secret: A unique string provided by the client to help identify this validation attempt - validated (bool|None): Whether sessions should be filtered by + address: The address of the 3PID + sid: The ID of the validation session + validated: Whether sessions should be filtered by whether they have been validated already or not. None to perform no filtering Returns: - Deferred[dict|None]: A dict containing the following: + A dict containing the following: * address - address of the 3pid * medium - medium of the 3pid * client_secret - a secret provided by the client for this validation session @@ -753,7 +717,7 @@ class RegistrationWorkerStore(SQLBaseStore): last_send_attempt, validated_at FROM threepid_validation_session WHERE %s """ % ( - " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), + " AND ".join("%s = ?" % k for k in keyvalues.keys()), ) if validated is not None: @@ -762,57 +726,57 @@ class RegistrationWorkerStore(SQLBaseStore): sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return None return rows[0] - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) - def delete_threepid_session(self, session_id): + async def delete_threepid_session(self, session_id: str) -> None: """Removes a threepid validation session from the database. This can be done after validation has been performed and whatever action was waiting on it has been carried out Args: - session_id (str): The ID of the session to delete + session_id: The ID of the session to delete """ def delete_threepid_session_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "access_tokens_device_index", index_name="access_tokens_device_id", table="access_tokens", columns=["user_id", "device_id"], ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "users_creation_ts", index_name="users_creation_ts", table="users", @@ -822,18 +786,19 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. - self.db.updates.register_noop_background_update("refresh_tokens_device_index") + self.db_pool.updates.register_noop_background_update( + "refresh_tokens_device_index" + ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "user_threepids_grandfather", self._bg_user_threepids_grandfather ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - @defer.inlineCallbacks - def _background_update_set_deactivated_flag(self, progress, batch_size): + async def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. """ @@ -861,7 +826,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): (last_user, batch_size), ) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return True, 0 @@ -875,7 +840,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): logger.info("Marked %d rows as deactivated", rows_processed_nb) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} ) @@ -884,17 +849,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): else: return False, len(rows) - end, nb_processed = yield self.db.runInteraction( + end, nb_processed = await self.db_pool.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) if end: - yield self.db.updates._end_background_update("users_set_deactivated_flag") + await self.db_pool.updates._end_background_update( + "users_set_deactivated_flag" + ) return nb_processed - @defer.inlineCallbacks - def _bg_user_threepids_grandfather(self, progress, batch_size): + async def _bg_user_threepids_grandfather(self, progress, batch_size): """We now track which identity servers a user binds their 3PID to, so we need to handle the case of existing bindings where we didn't track this. @@ -915,20 +881,21 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): txn.executemany(sql, [(id_server,) for id_server in id_servers]) if id_servers: - yield self.db.runInteraction( + await self.db_pool.runInteraction( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) - yield self.db.updates._end_background_update("user_threepids_grandfather") + await self.db_pool.updates._end_background_update("user_threepids_grandfather") return 1 class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors if self._account_validity.enabled: self._clock.call_later( @@ -949,23 +916,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): + async def add_access_token_to_user( + self, + user_id: str, + token: str, + device_id: Optional[str], + valid_until_ms: Optional[int], + ) -> None: """Adds an access token for the given user. Args: - user_id (str): The user ID. - token (str): The new access token to add. - device_id (str): ID of the device to associate with the access - token - valid_until_ms (int|None): when the token is valid until. None for - no expiry. + user_id: The user ID. + token: The new access token to add. + device_id: ID of the device to associate with the access token + valid_until_ms: when the token is valid until. None for no expiry. Raises: StoreError if there was a problem adding this. """ next_id = self._access_tokens_id_gen.get_next() - yield self.db.simple_insert( + await self.db_pool.simple_insert( "access_tokens", { "id": next_id, @@ -977,40 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_access_token_to_user", ) - def register_user( + async def register_user( self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - ): + user_id: str, + password_hash: Optional[str] = None, + was_guest: bool = False, + make_guest: bool = False, + appservice_id: Optional[str] = None, + create_profile_with_displayname: Optional[str] = None, + admin: bool = False, + user_type: Optional[str] = None, + shadow_banned: bool = False, + ) -> None: """Attempts to register an account. Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being - upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, - false to add a regular user account. - appservice_id (str): The ID of the appservice registering the user. - create_profile_with_displayname (unicode): Optionally create a profile for + user_id: The desired user ID to register. + password_hash: Optional. The password hash for this user. + was_guest: Whether this is a guest account being upgraded to a + non-guest account. + make_guest: True if the the new user should be guest, false to add a + regular user account. + appservice_id: The ID of the appservice registering the user. + create_profile_with_displayname: Optionally create a profile for the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from - api.constants.UserTypes, or None for a normal user. + admin: is an admin user? + user_type: type of user. One of the values from api.constants.UserTypes, + or None for a normal user. + shadow_banned: Whether the user is shadow-banned, i.e. they may be + told their requests succeeded but we ignore them. Raises: StoreError if the user_id could not be registered. - - Returns: - Deferred """ - return self.db.runInteraction( + await self.db_pool.runInteraction( "register_user", self._register_user, user_id, @@ -1021,6 +991,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ) def _register_user( @@ -1034,6 +1005,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ): user_id_obj = UserID.from_string(user_id) @@ -1044,7 +1016,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self.db.simple_select_one_txn( + self.db_pool.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1052,7 +1024,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): allow_none=False, ) - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1063,10 +1035,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) else: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "users", values={ @@ -1077,6 +1050,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) @@ -1109,11 +1083,10 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - txn.call_after(self.is_guest.invalidate, (user_id,)) - def record_user_external_id( + async def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Deferred: + ) -> None: """Record a mapping from an external user id to a mxid Args: @@ -1121,7 +1094,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self.db.simple_insert( + await self.db_pool.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1131,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="record_user_external_id", ) - def user_set_password_hash(self, user_id, password_hash): + async def user_set_password_hash(self, user_id: str, password_hash: str) -> None: """ NB. This does *not* evict any cache because the one use for this removes most of the entries subsequently anyway so it would be @@ -1139,29 +1112,30 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def user_set_password_hash_txn(txn): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db.runInteraction( + await self.db_pool.runInteraction( "user_set_password_hash", user_set_password_hash_txn ) - def user_set_consent_version(self, user_id, consent_version): + async def user_set_consent_version( + self, user_id: str, consent_version: str + ) -> None: """Updates the user table to record privacy policy consent Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy the user has consented - to + user_id: full mxid of the user to update + consent_version: version of the policy the user has consented to Raises: StoreError(404) if user not found """ def f(txn): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1169,23 +1143,24 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db.runInteraction("user_set_consent_version", f) + await self.db_pool.runInteraction("user_set_consent_version", f) - def user_set_consent_server_notice_sent(self, user_id, consent_version): + async def user_set_consent_server_notice_sent( + self, user_id: str, consent_version: str + ) -> None: """Updates the user table to record that we have sent the user a server notice about privacy policy consent Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy we have notified the - user about + user_id: full mxid of the user to update + consent_version: version of the policy we have notified the user about Raises: StoreError(404) if user not found """ def f(txn): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1193,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db.runInteraction("user_set_consent_server_notice_sent", f) + await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) - def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): + async def user_delete_access_tokens( + self, + user_id: str, + except_token_id: Optional[str] = None, + device_id: Optional[str] = None, + ) -> List[Tuple[str, int, Optional[str]]]: """ Invalidate access tokens belonging to a user Args: - user_id (str): ID of user the tokens belong to - except_token_id (str): list of access_tokens IDs which should - *not* be deleted - device_id (str|None): ID of device the tokens are associated with. + user_id: ID of user the tokens belong to + except_token_id: access_tokens ID which should *not* be deleted + device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted Returns: - defer.Deferred[list[str, int, str|None, int]]: a list of - (token, token id, device id) for each of the deleted tokens + A tuple of (token, token id, device id) for each of the deleted tokens """ def f(txn): @@ -1239,11 +1217,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return tokens_and_devices - return self.db.runInteraction("user_delete_access_tokens", f) + return await self.db_pool.runInteraction("user_delete_access_tokens", f) - def delete_access_token(self, access_token): + async def delete_access_token(self, access_token: str) -> None: def f(txn): - self.db.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1251,11 +1229,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, self.get_user_by_access_token, (access_token,) ) - return self.db.runInteraction("delete_access_token", f) + await self.db_pool.runInteraction("delete_access_token", f) - @cachedInlineCallbacks() - def is_guest(self, user_id): - res = yield self.db.simple_select_one_onecol( + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1265,36 +1243,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return res if res else False - def add_user_pending_deactivation(self, user_id): + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self.db.simple_insert( + await self.db_pool.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", ) - def del_user_pending_deactivation(self, user_id): + async def del_user_pending_deactivation(self, user_id: str) -> None: """ Removes the given user to the table of users who need to be parted from all the rooms they're in, effectively marking that user as fully deactivated. """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self.db.simple_delete( + await self.db_pool.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", ) - def get_user_pending_deactivation(self): + async def get_user_pending_deactivation(self) -> Optional[str]: """ Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1302,29 +1280,30 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="get_users_pending_deactivation", ) - def validate_threepid_session(self, session_id, client_secret, token, current_ts): + async def validate_threepid_session( + self, session_id: str, client_secret: str, token: str, current_ts: int + ) -> Optional[str]: """Attempt to validate a threepid session using a token Args: - session_id (str): The id of a validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - token (str): A validation token - current_ts (int): The current unix time in milliseconds. Used for - checking token expiry status + session_id: The id of a validation session + client_secret: A unique string provided by the client to help identify + this validation attempt + token: A validation token + current_ts: The current unix time in milliseconds. Used for checking + token expiry status Raises: ThreepidValidationError: if a matching validation token was not found or has expired Returns: - deferred str|None: A str representing a link to redirect the user - to if there is one. + A str representing a link to redirect the user to if there is one. """ # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self.db.simple_select_one_txn( + row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1333,16 +1312,23 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) if not row: - raise ThreepidValidationError(400, "Unknown session_id") + if self._ignore_unknown_session_error: + # If we need to inhibit the error caused by an incorrect session ID, + # use None as placeholder values for the client secret and the + # validation timestamp. + # It shouldn't be an issue because they're both only checked after + # the token check, which should fail. And if it doesn't for some + # reason, the next check is on the client secret, which is NOT NULL, + # so we don't have to worry about the client secret matching by + # accident. + row = {"client_secret": None, "validated_at": None} + else: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] validated_at = row["validated_at"] - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - - row = self.db.simple_select_one_txn( + row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1357,6 +1343,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expires = row["expires"] next_link = row["next_link"] + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + # If the session is already validated, no need to revalidate if validated_at: return next_link @@ -1367,7 +1358,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Looks good. Validate the session - self.db.simple_update_txn( + self.db_pool.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1377,78 +1368,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return next_link # Return next_link if it exists - return self.db.runInteraction( + return await self.db_pool.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) - def upsert_threepid_validation_session( + async def start_or_continue_validation_session( self, - medium, - address, - client_secret, - send_attempt, - session_id, - validated_at=None, - ): - """Upsert a threepid validation session - Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - session_id (str): The id of this validation session - validated_at (int|None): The unix timestamp in milliseconds of - when the session was marked as valid - """ - insertion_values = { - "medium": medium, - "address": address, - "client_secret": client_secret, - } - - if validated_at: - insertion_values["validated_at"] = validated_at - - return self.db.simple_upsert( - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - values={"last_send_attempt": send_attempt}, - insertion_values=insertion_values, - desc="upsert_threepid_validation_session", - ) - - def start_or_continue_validation_session( - self, - medium, - address, - session_id, - client_secret, - send_attempt, - next_link, - token, - token_expires, - ): + medium: str, + address: str, + session_id: str, + client_secret: str, + send_attempt: int, + next_link: Optional[str], + token: str, + token_expires: int, + ) -> None: """Creates a new threepid validation session if it does not already exist and associates a new validation token with it Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - session_id (str): The id of this validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - next_link (str|None): The link to redirect the user to upon - successful validation - token (str): The validation token - token_expires (int): The timestamp for which after the token - will no longer be valid + medium: The medium of the 3PID + address: The address of the 3PID + session_id: The id of this validation session + client_secret: A unique string provided by the client to help + identify this validation attempt + send_attempt: The latest send_attempt on this session + next_link: The link to redirect the user to upon successful validation + token: The validation token + token_expires: The timestamp for which after the token will no + longer be valid """ def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1461,7 +1414,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Create a new validation token with this session ID - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1472,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) - def cull_expired_threepid_validation_tokens(self): + async def cull_expired_threepid_validation_tokens(self) -> None: """Remove threepid validation tokens with expiry dates that have passed""" def cull_expired_threepid_validation_tokens_txn(txn, ts): @@ -1485,24 +1438,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): DELETE FROM threepid_validation_token WHERE expires < ? """ - return txn.execute(sql, (ts,)) + txn.execute(sql, (ts,)) - return self.db.runInteraction( + await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), ) - @defer.inlineCallbacks - def set_user_deactivated_status(self, user_id, deactivated): + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: """Set the `deactivated` property for the provided user to the provided value. Args: - user_id (str): The ID of the user to set the status for. - deactivated (bool): The value to set for `deactivated`. + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. """ - yield self.db.runInteraction( + await self.db_pool.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, user_id, @@ -1510,7 +1464,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1519,9 +1473,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + txn.call_after(self.is_guest.invalidate, (user_id,)) - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): + async def _set_expiration_date_when_missing(self): """ Retrieves the list of registered users that don't have an expiration date, and adds an expiration date for each of them. @@ -1538,14 +1492,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) txn.execute(sql, []) - res = self.db.cursor_to_dict(txn) + res = self.db_pool.cursor_to_dict(txn) if res: for user in res: self.set_expiration_date_for_user_txn( txn, user["name"], use_delta=True ) - yield self.db.runInteraction( + await self.db_pool.runInteraction( "get_users_with_no_expiration_date", select_users_with_no_expiration_date_txn, ) @@ -1569,9 +1523,32 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expiration_ts, ) - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, "account_validity", keyvalues={"user_id": user_id}, values={"expiration_ts_ms": expiration_ts, "email_sent": False}, ) + + +def find_max_generated_user_id_localpart(cur: Cursor) -> int: + """ + Gets the localpart of the max current generated user ID. + + Generated user IDs are integers, so we find the largest integer user ID + already taken and return that. + """ + + # We bound between '@0' and '@a' to avoid pulling the entire table + # out. + cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") + + regex = re.compile(r"^@(\d+):") + + max_found = 0 + + for (user_id,) in cur: + match = regex.search(user_id) + if match: + max_found = max(int(match.group(1)), max_found) + return max_found diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/databases/main/rejections.py index 27e5a2084a..1e361aaa9a 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/databases/main/rejections.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from synapse.storage._base import SQLBaseStore @@ -21,8 +22,8 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - def get_rejection_reason(self, event_id): - return self.db.simple_select_one_onecol( + async def get_rejection_reason(self, event_id: str) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/databases/main/relations.py index 7d477f8d01..5cd61547f7 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,56 +14,53 @@ # limitations under the License. import logging +from typing import Optional import attr from synapse.api.constants import RelationTypes +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.stream import generate_pagination_where_clause +from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) - def get_relations_for_event( + async def get_relations_for_event( self, - event_id, - relation_type=None, - event_type=None, - aggregation_key=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + aggregation_key: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[RelationPaginationToken] = None, + to_token: Optional[RelationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of relations for an event, ordered by topological ordering. Args: - event_id (str): Fetch events that relate to this event ID. - relation_type (str|None): Only fetch events with this relation - type, if given. - event_type (str|None): Only fetch events with this event type, if - given. - aggregation_key (str|None): Only fetch events with this aggregation - key, if given. - limit (int): Only fetch the most recent `limit` events. - direction (str): Whether to fetch the most recent first (`"b"`) or - the oldest first (`"f"`). - from_token (RelationPaginationToken|None): Fetch rows from the given - token, or from the start if None. - to_token (RelationPaginationToken|None): Fetch rows up to the given - token, or up to the end if None. + event_id: Fetch events that relate to this event ID. + relation_type: Only fetch events with this relation type, if given. + event_type: Only fetch events with this event type, if given. + aggregation_key: Only fetch events with this aggregation key, if given. + limit: Only fetch the most recent `limit` events. + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`). + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of event IDs that match relations - requested. The rows are of the form `{"event_id": "..."}`. + List of event IDs that match relations requested. The rows are of + the form `{"event_id": "..."}`. """ where_clause = ["relates_to_id = ?"] @@ -129,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @cached(tree=True) - def get_aggregation_groups_for_event( + async def get_aggregation_groups_for_event( self, - event_id, - event_type=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + event_type: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[AggregationPaginationToken] = None, + to_token: Optional[AggregationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -150,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore): on an event. Args: - event_id (str): Fetch events that relate to this event ID. - event_type (str|None): Only fetch events with this event type, if - given. - limit (int): Only fetch the `limit` groups. - direction (str): Whether to fetch the highest count first (`"b"`) or + event_id: Fetch events that relate to this event ID. + event_type: Only fetch events with this event type, if given. + limit: Only fetch the `limit` groups. + direction: Whether to fetch the highest count first (`"b"`) or the lowest count first (`"f"`). - from_token (AggregationPaginationToken|None): Fetch rows from the - given token, or from the start if None. - to_token (AggregationPaginationToken|None): Fetch rows up to the - given token, or up to the end if None. - + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of groups of annotations that - match. Each row is a dict with `type`, `key` and `count` fields. + List of groups of annotations that match. Each row is a dict with + `type`, `key` and `count` fields. """ where_clause = ["relates_to_id = ?", "relation_type = ?"] @@ -223,22 +216,22 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): + @cached() + async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. Correctly handles checking whether edits were allowed to happen. Args: - event_id (str): The original event ID + event_id: The original event ID Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. + The most recent edit, if any. """ # We only allow edits for `m.room.message` events that have the same sender @@ -268,28 +261,29 @@ class RelationsWorkerStore(SQLBaseStore): if row: return row[0] - edit_id = yield self.db.runInteraction( + edit_id = await self.db_pool.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) if not edit_id: - return + return None - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event + return await self.get_event(edit_id, allow_none=True) - def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + async def has_user_annotated_event( + self, parent_id: str, event_type: str, aggregation_key: str, sender: str + ) -> bool: """Check if a user has already annotated an event with the same key (e.g. already liked an event). Args: - parent_id (str): The event being annotated - event_type (str): The event type of the annotation - aggregation_key (str): The aggregation key of the annotation - sender (str): The sender of the annotation + parent_id: The event being annotated + event_type: The event type of the annotation + aggregation_key: The aggregation key of the annotation + sender: The sender of the annotation Returns: - Deferred[bool] + True if the event is already annotated. """ sql = """ @@ -318,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore): return bool(txn.fetchone()) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/databases/main/room.py index 46f643c6b9..717df97301 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -21,26 +21,19 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from canonicaljson import json - -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.search import SearchStore -from synapse.storage.database import Database, LoggingTransaction -from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.search import SearchStore +from synapse.types import JsonDict, ThirdPartyInstanceID +from synapse.util import json_encoder +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) -OpsLevel = collections.namedtuple( - "OpsLevel", ("ban_level", "kick_level", "redact_level") -) - RatelimitOverride = collections.namedtuple( "RatelimitOverride", ("messages_per_second", "burst_count") ) @@ -75,20 +68,20 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config - def get_room(self, room_id): + async def get_room(self, room_id: str) -> dict: """Retrieve a room. Args: - room_id (str): The ID of the room to retrieve. + room_id: The ID of the room to retrieve. Returns: A dict containing the room information, or None if the room is unknown. """ - return self.db.simple_select_one( + return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -96,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore): allow_none=True, ) - def get_room_with_stats(self, room_id: str): + async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve room with statistics. Args: @@ -118,30 +111,39 @@ class RoomWorkerStore(SQLBaseStore): WHERE room_id = ? """ txn.execute(sql, [room_id]) - res = self.db.cursor_to_dict(txn)[0] + # Catch error if sql returns empty result to return "None" instead of an error + try: + res = self.db_pool.cursor_to_dict(txn)[0] + except IndexError: + return None + res["federatable"] = bool(res["federatable"]) res["public"] = bool(res["public"]) return res - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id ) - def get_public_room_ids(self): - return self.db.simple_select_onecol( + async def get_public_room_ids(self) -> List[str]: + return await self.db_pool.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", desc="get_public_room_ids", ) - def count_public_rooms(self, network_tuple, ignore_non_federatable): + async def count_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + ignore_non_federatable: bool, + ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. Args: - network_tuple (ThirdPartyInstanceID|None) - ignore_non_federatable (bool): If true filters out non-federatable rooms + network_tuple + ignore_non_federatable: If true filters out non-federatable rooms """ def _count_public_rooms_txn(txn): @@ -185,10 +187,11 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) + return await self.db_pool.runInteraction( + "count_public_rooms", _count_public_rooms_txn + ) - @defer.inlineCallbacks - def get_largest_public_rooms( + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], search_filter: Optional[dict], @@ -318,21 +321,21 @@ class RoomWorkerStore(SQLBaseStore): def _get_largest_public_rooms_txn(txn): txn.execute(sql, query_args) - results = self.db.cursor_to_dict(txn) + results = self.db_pool.cursor_to_dict(txn) if not forwards: results.reverse() return results - ret_val = yield self.db.runInteraction( + ret_val = await self.db_pool.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) - defer.returnValue(ret_val) + return ret_val @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self.db.simple_select_one_onecol( + async def is_room_blocked(self, room_id: str) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -500,12 +503,12 @@ class RoomWorkerStore(SQLBaseStore): room_count = txn.fetchone() return rooms, room_count[0] - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_paginate", _get_rooms_paginate_txn, ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): + @cached(max_entries=10000) + async def get_ratelimit_for_user(self, user_id): """Check if there are any overrides for ratelimiting for the given user @@ -517,7 +520,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self.db.simple_select_one( + row = await self.db_pool.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -533,8 +536,8 @@ class RoomWorkerStore(SQLBaseStore): else: return None - @cachedInlineCallbacks() - def get_retention_policy_for_room(self, room_id): + @cached() + async def get_retention_policy_for_room(self, room_id): """Get the retention policy for a given room. If no retention policy has been found for this room, returns a policy defined @@ -559,21 +562,19 @@ class RoomWorkerStore(SQLBaseStore): (room_id,), ) - return self.db.cursor_to_dict(txn) + return self.db_pool.cursor_to_dict(txn) - ret = yield self.db.runInteraction( + ret = await self.db_pool.runInteraction( "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) # If we don't know this room ID, ret will be None, in this case return the default # policy. if not ret: - defer.returnValue( - { - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, - } - ) + return { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } row = ret[0] @@ -587,17 +588,16 @@ class RoomWorkerStore(SQLBaseStore): if row["max_lifetime"] is None: row["max_lifetime"] = self.config.retention_default_max_lifetime - defer.returnValue(row) + return row - def get_media_mxcs_in_room(self, room_id): + async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room Args: - room_id (str) + room_id Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. + The local and remote media as a lists of the media IDs. """ def _get_media_mxcs_in_room_txn(txn): @@ -613,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_media_ids_in_room", _get_media_mxcs_in_room_txn ) - def quarantine_media_ids_in_room(self, room_id, quarantined_by): + async def quarantine_media_ids_in_room( + self, room_id: str, quarantined_by: str + ) -> int: """For a room loops through all events with media and quarantines the associated media """ @@ -626,37 +628,11 @@ class RoomWorkerStore(SQLBaseStore): def _quarantine_media_in_room_txn(txn): local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - total_media_quarantined = 0 - - # Now update all the tables to set the quarantined_by flag - - txn.executemany( - """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? - """, - ((quarantined_by, media_id) for media_id in local_mxcs), - ) - - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - ( - (quarantined_by, origin, media_id) - for origin, media_id in remote_mxcs - ), + return self._quarantine_media_txn( + txn, local_mxcs, remote_mxcs, quarantined_by ) - total_media_quarantined += len(local_mxcs) - total_media_quarantined += len(remote_mxcs) - - return total_media_quarantined - - return self.db.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -691,7 +667,7 @@ class RoomWorkerStore(SQLBaseStore): next_token = None for stream_ordering, content_json in txn: next_token = stream_ordering - event_json = json.loads(content_json) + event_json = db_to_json(content_json) content = event_json["content"] content_url = content.get("url") thumbnail_url = content.get("info", {}).get("thumbnail_url") @@ -719,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - def quarantine_media_by_id( + async def quarantine_media_by_id( self, server_name: str, media_id: str, quarantined_by: str, - ): + ) -> int: """quarantines a single local or remote media id Args: @@ -740,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_id_txn ) - def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + async def quarantine_media_ids_by_user( + self, user_id: str, quarantined_by: str + ) -> int: """quarantines all local media associated with a single user Args: @@ -756,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore): local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_user_txn ) @@ -805,17 +783,17 @@ class RoomWorkerStore(SQLBaseStore): Returns: The total number of media items quarantined """ - total_media_quarantined = 0 - # Update all the tables to set the quarantined_by flag txn.executemany( """ UPDATE local_media_repository SET quarantined_by = ? - WHERE media_id = ? + WHERE media_id = ? AND safe_from_quarantine = ? """, - ((quarantined_by, media_id) for media_id in local_mxcs), + ((quarantined_by, media_id, False) for media_id in local_mxcs), ) + # Note that a rowcount of -1 can be used to indicate no rows were affected. + total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 txn.executemany( """ @@ -825,13 +803,36 @@ class RoomWorkerStore(SQLBaseStore): """, ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), ) - - total_media_quarantined += len(local_mxcs) - total_media_quarantined += len(remote_mxcs) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 return total_media_quarantined - def get_all_new_public_rooms(self, prev_id, current_id, limit): + async def get_all_new_public_rooms( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for public rooms replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: + return [], current_id, False + def get_all_new_public_rooms(txn): sql = """ SELECT stream_id, room_id, visibility, appservice_id, network_id @@ -841,13 +842,17 @@ class RoomWorkerStore(SQLBaseStore): LIMIT ? """ - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() + txn.execute(sql, (last_id, current_id, limit)) + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - if prev_id == current_id: - return defer.succeed([]) + return updates, upto_token, limited - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_public_rooms", get_all_new_public_rooms ) @@ -856,27 +861,26 @@ class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.config = hs.config - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, self._remove_tombstoned_rooms_from_directory, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.ADD_ROOMS_ROOM_VERSION_COLUMN, self._background_add_rooms_room_version_column, ) - @defer.inlineCallbacks - def _background_insert_retention(self, progress, batch_size): + async def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. NULLs the property's columns if missing from the retention event in the room's @@ -900,7 +904,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): (last_room, batch_size), ) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return True @@ -909,10 +913,10 @@ class RoomBackgroundUpdateStore(SQLBaseStore): if not row["json"]: retention_policy = {} else: - ev = json.loads(row["json"]) - retention_policy = json.dumps(ev["content"]) + ev = db_to_json(row["json"]) + retention_policy = ev["content"] - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn=txn, table="room_retention", values={ @@ -925,7 +929,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): logger.info("Inserted %d rows into room_retention", len(rows)) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} ) @@ -934,14 +938,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore): else: return False - end = yield self.db.runInteraction( + end = await self.db_pool.runInteraction( "insert_room_retention", _background_insert_retention_txn, ) if end: - yield self.db.updates._end_background_update("insert_room_retention") + await self.db_pool.updates._end_background_update("insert_room_retention") - defer.returnValue(batch_size) + return batch_size async def _background_add_rooms_room_version_column( self, progress: dict, batch_size: int @@ -965,7 +969,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): updates = [] for room_id, event_json in txn: - event_dict = json.loads(event_json) + event_dict = db_to_json(event_json) room_version_id = event_dict.get("content", {}).get( "room_version", RoomVersions.V1.identifier ) @@ -983,7 +987,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): # mainly for paranoia as much badness would happen if we don't # insert the row and then try and get the room version for the # room. - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="rooms", keyvalues={"room_id": room_id}, @@ -992,19 +996,19 @@ class RoomBackgroundUpdateStore(SQLBaseStore): ) new_last_room_id = room_id - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id} ) return False - end = await self.db.runInteraction( + end = await self.db_pool.runInteraction( "_background_add_rooms_room_version_column", _background_add_rooms_room_version_column_txn, ) if end: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.ADD_ROOMS_ROOM_VERSION_COLUMN ) @@ -1038,12 +1042,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return [row[0] for row in txn] - rooms = await self.db.runInteraction( + rooms = await self.db_pool.runInteraction( "get_tombstoned_directory_rooms", _get_rooms ) if not rooms: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE ) return 0 @@ -1052,7 +1056,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): logger.info("Removing tombstoned room %s from the directory", room_id) await self.set_room_is_public(room_id, False) - await self.db.updates._background_update_progress( + await self.db_pool.updates._background_update_progress( self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} ) @@ -1068,7 +1072,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomStore, self).__init__(database, db_conn, hs) self.config = hs.config @@ -1079,7 +1083,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): Called when we join a room over federation, and overwrites any room version currently in the table. """ - await self.db.simple_upsert( + await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", keyvalues={"room_id": room_id}, @@ -1090,8 +1094,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def store_room( + async def store_room( self, room_id: str, room_creator_user_id: str, @@ -1112,7 +1115,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): try: def store_room_txn(txn, next_id): - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, "rooms", { @@ -1123,7 +1126,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) if is_public: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -1133,8 +1136,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) + with await self._public_room_id_gen.get_next() as next_id: + await self.db_pool.runInteraction( + "store_room_txn", store_room_txn, next_id + ) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -1144,7 +1149,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): When we receive an invite over federation, store the version of the room if we don't already know the room version. """ - await self.db.simple_upsert( + await self.db_pool.simple_upsert( desc="maybe_store_room_on_invite", table="rooms", keyvalues={"room_id": room_id}, @@ -1159,17 +1164,16 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def set_room_is_public(self, room_id, is_public): + async def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="rooms", keyvalues={"room_id": room_id}, updatevalues={"is_public": is_public}, ) - entries = self.db.simple_select_list_txn( + entries = self.db_pool.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -1187,7 +1191,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -1199,14 +1203,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + with await self._public_room_id_gen.get_next() as next_id: + await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() - @defer.inlineCallbacks - def set_room_is_public_appservice( + async def set_room_is_public_appservice( self, room_id, appservice_id, network_id, is_public ): """Edit the appservice/network specific public room list. @@ -1227,7 +1230,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="appservice_room_list", values={ @@ -1240,7 +1243,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): # We've already inserted, nothing to do. return else: - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="appservice_room_list", keyvalues={ @@ -1250,7 +1253,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - entries = self.db.simple_select_list_txn( + entries = self.db_pool.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -1268,7 +1271,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -1280,16 +1283,16 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + with await self._public_room_id_gen.get_next() as next_id: + await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, ) self.hs.get_notifier().on_new_replication_data() - def get_room_count(self): - """Retrieve a list of all rooms + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. """ def f(txn): @@ -1298,13 +1301,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.db.runInteraction("get_rooms", f) + return await self.db_pool.runInteraction("get_rooms", f) - def add_event_report( - self, room_id, event_id, user_id, reason, content, received_ts - ): + async def add_event_report( + self, + room_id: str, + event_id: str, + user_id: str, + reason: str, + content: JsonDict, + received_ts: int, + ) -> None: next_id = self._event_reports_id_gen.get_next() - return self.db.simple_insert( + await self.db_pool.simple_insert( table="event_reports", values={ "id": next_id, @@ -1313,7 +1322,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "event_id": event_id, "user_id": user_id, "reason": reason, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, desc="add_event_report", ) @@ -1321,52 +1330,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - @defer.inlineCallbacks - def block_room(self, room_id, user_id): + async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. Can be called multiple times. Args: - room_id (str): Room to block - user_id (str): Who blocked it - - Returns: - Deferred + room_id: Room to block + user_id: Who blocked it """ - yield self.db.simple_upsert( + await self.db_pool.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - yield self.db.runInteraction( + await self.db_pool.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, (room_id,), ) - @defer.inlineCallbacks - def get_rooms_for_retention_period_in_range( - self, min_ms, max_ms, include_null=False - ): + async def get_rooms_for_retention_period_in_range( + self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False + ) -> Dict[str, dict]: """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. Args: - min_ms (int|None): Duration in milliseconds that define the lower limit of + min_ms: Duration in milliseconds that define the lower limit of the range to handle (exclusive). If None, doesn't set a lower limit. - max_ms (int|None): Duration in milliseconds that define the upper limit of + max_ms: Duration in milliseconds that define the upper limit of the range to handle (inclusive). If None, doesn't set an upper limit. - include_null (bool): Whether to include rooms which retention policy is NULL + include_null: Whether to include rooms which retention policy is NULL in the returned set. Returns: - dict[str, dict]: The rooms within this range, along with their retention - policy. The key is "room_id", and maps to a dict describing the retention - policy associated with this room ID. The keys for this nested dict are - "min_lifetime" (int|None), and "max_lifetime" (int|None). + The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). """ def get_rooms_for_retention_period_in_range_txn(txn): @@ -1396,7 +1400,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): txn.execute(sql, args) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) rooms_dict = {} for row in rows: @@ -1412,7 +1416,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): txn.execute(sql) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) # If a room isn't already in the dict (i.e. it doesn't have a retention # policy in its state), add it with a null policy. @@ -1425,9 +1429,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return rooms_dict - rooms = yield self.db.runInteraction( + rooms = await self.db_pool.runInteraction( "get_rooms_for_retention_period_in_range", get_rooms_for_retention_period_in_range_txn, ) - defer.returnValue(rooms) + return rooms diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/databases/main/roommember.py index 137ebac833..91a8b43da3 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,24 +15,21 @@ # limitations under the License. import logging -from typing import Iterable, List, Set - -from six import iteritems, itervalues - -from canonicaljson import json - -from twisted.internet import defer +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import ( LoggingTransaction, SQLBaseStore, + db_to_json, make_in_list_sql_clause, ) -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( GetRoomsForUserWithStreamOrdering, @@ -43,9 +40,12 @@ from synapse.storage.roommember import ( from synapse.types import Collection, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.state import _StateCacheEntry + logger = logging.getLogger(__name__) @@ -54,7 +54,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the @@ -90,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): lambda: self._known_servers_count, ) - @defer.inlineCallbacks - def _count_known_servers(self): + async def _count_known_servers(self): """ Count the servers that this server knows about. @@ -119,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(query) return list(txn)[0][0] - count = yield self.db.runInteraction("get_known_servers", _transact) + count = await self.db_pool.runInteraction("get_known_servers", _transact) # We always know about ourselves, even if we have nothing in # room_memberships (for example, the server is new). @@ -131,7 +130,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership column is up to date """ - pending_update = self.db.simple_select_one_txn( + pending_update = self.db_pool.simple_select_one_txn( txn, table="background_updates", keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, @@ -147,18 +146,18 @@ class RoomMemberWorkerStore(EventsWorkerStore): 15.0, run_as_background_process, "_check_safe_current_state_events_membership_updated", - self.db.runInteraction, + self.db_pool.runInteraction, "_check_safe_current_state_events_membership_updated", self._check_safe_current_state_events_membership_updated_txn, ) @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id): - return self.db.runInteraction( + async def get_users_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) - def get_users_in_room_txn(self, txn, room_id): + def get_users_in_room_txn(self, txn, room_id: str) -> List[str]: # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. @@ -181,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [r[0] for r in txn] @cached(max_entries=100000) - def get_room_summary(self, room_id): + async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: """ Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: - room_id (str): The room ID to query + room_id: The room ID to query Returns: - Deferred[dict[str, MemberSummary]: - dict of membership states, pointing to a MemberSummary named tuple. + dict of membership states, pointing to a MemberSummary named tuple. """ def _get_room_summary_txn(txn): @@ -262,80 +260,63 @@ class RoomMemberWorkerStore(EventsWorkerStore): return res - return self.db.runInteraction("get_room_summary", _get_room_summary_txn) - - def _get_user_counts_in_room_txn(self, txn, room_id): - """ - Get the user count in a room by membership. - - Args: - room_id (str) - membership (Membership) - - Returns: - Deferred[int] - """ - sql = """ - SELECT m.membership, count(*) FROM room_memberships as m - INNER JOIN current_state_events as c USING(event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - return {row[0]: row[1] for row in txn} + return await self.db_pool.runInteraction( + "get_room_summary", _get_room_summary_txn + ) @cached() - def get_invited_rooms_for_local_user(self, user_id): - """ Get all the rooms the *local* user is invited to + async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser: + """Get all the rooms the *local* user is invited to. Args: - user_id (str): The user ID. + user_id: The user ID. + Returns: - A deferred list of RoomsForUser. + A list of RoomsForUser. """ - return self.get_rooms_for_local_user_where_membership_is( + return await self.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE] ) - @defer.inlineCallbacks - def get_invite_for_local_user_in_room(self, user_id, room_id): - """Gets the invite for the given *local* user and room + async def get_invite_for_local_user_in_room( + self, user_id: str, room_id: str + ) -> Optional[RoomsForUser]: + """Gets the invite for the given *local* user and room. Args: - user_id (str) - room_id (str) + user_id: The user ID to find the invite of. + room_id: The room to user was invited to. Returns: - Deferred: Resolves to either a RoomsForUser or None if no invite was - found. + Either a RoomsForUser or None if no invite was found. """ - invites = yield self.get_invited_rooms_for_local_user(user_id) + invites = await self.get_invited_rooms_for_local_user(user_id) for invite in invites: if invite.room_id == room_id: return invite return None - @defer.inlineCallbacks - def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this *local* user where the membership for this user + async def get_rooms_for_local_user_where_membership_is( + self, user_id: str, membership_list: Collection[str] + ) -> List[RoomsForUser]: + """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. Filters out forgotten rooms. Args: - user_id (str): The user ID. - membership_list (list): A list of synapse.api.constants.Membership - values which the user must be in. + user_id: The user ID. + membership_list: A list of synapse.api.constants.Membership + values which the user must be in. Returns: - Deferred[list[RoomsForUser]] + The RoomsForUser that the user matches the membership types. """ if not membership_list: - return defer.succeed(None) + return [] - rooms = yield self.db.runInteraction( + rooms = await self.db_pool.runInteraction( "get_rooms_for_local_user_where_membership_is", self._get_rooms_for_local_user_where_membership_is_txn, user_id, @@ -343,12 +324,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) # Now we filter out forgotten rooms - forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) + forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id) return [room for room in rooms if room.room_id not in forgotten_rooms] def _get_rooms_for_local_user_where_membership_is_txn( - self, txn, user_id, membership_list - ): + self, txn, user_id: str, membership_list: List[str] + ) -> List[RoomsForUser]: # Paranoia check. if not self.hs.is_mine_id(user_id): raise Exception( @@ -372,32 +353,35 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] + results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)] return results @cached(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id): + async def get_rooms_for_user_with_stream_ordering( + self, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. Args: - user_id (str) + user_id Returns: - Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns - the rooms the user is in currently, along with the stream ordering - of the most recent join for that user and room. + Returns the rooms the user is in currently, along with the stream + ordering of the most recent join for that user and room. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_user_with_stream_ordering", self._get_rooms_for_user_with_stream_ordering_txn, user_id, ) - def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): + def _get_rooms_for_user_with_stream_ordering_txn( + self, txn, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called # for rooms the server is participating in. @@ -424,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) - - return results + return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] @@ -456,42 +438,44 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {row[0] for row in txn} - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "get_users_server_still_shares_room_with", _get_users_server_still_shares_room_with_txn, ) - @defer.inlineCallbacks - def get_rooms_for_user(self, user_id, on_invalidate=None): + async def get_rooms_for_user(self, user_id: str, on_invalidate=None): """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. """ - rooms = yield self.get_rooms_for_user_with_stream_ordering( + rooms = await self.get_rooms_for_user_with_stream_ordering( user_id, on_invalidate=on_invalidate ) return frozenset(r.room_id for r in rooms) - @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) - def get_users_who_share_room_with_user(self, user_id, cache_context): + @cached(max_entries=500000, cache_context=True, iterable=True) + async def get_users_who_share_room_with_user( + self, user_id: str, cache_context: _CacheContext + ) -> Set[str]: """Returns the set of users who share a room with `user_id` """ - room_ids = yield self.get_rooms_for_user( + room_ids = await self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate ) user_who_share_room = set() for room_id in room_ids: - user_ids = yield self.get_users_in_room( + user_ids = await self.get_users_in_room( room_id, on_invalidate=cache_context.invalidate ) user_who_share_room.update(user_ids) return user_who_share_room - @defer.inlineCallbacks - def get_joined_users_from_context(self, event, context): + async def get_joined_users_from_context( + self, event: EventBase, context: EventContext + ): state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -500,14 +484,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() - result = yield self._get_joined_users_from_context( + current_state_ids = await context.get_current_state_ids() + return await self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) - return result - @defer.inlineCallbacks - def get_joined_users_from_state(self, room_id, state_entry): + async def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -517,16 +499,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() with Measure(self._clock, "get_joined_users_from_state"): - return ( - yield self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) + return await self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry ) - @cachedInlineCallbacks( - num_args=2, cache_context=True, iterable=True, max_entries=100000 - ) - def _get_joined_users_from_context( + @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) + async def _get_joined_users_from_context( self, room_id, state_group, @@ -538,13 +516,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None users_in_room = {} member_event_ids = [ e_id - for key, e_id in iteritems(current_state_ids) + for key, e_id in current_state_ids.items() if key[0] == EventTypes.Member ] @@ -561,7 +538,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): users_in_room = dict(prev_res) member_event_ids = [ e_id - for key, e_id in iteritems(context.delta_ids) + for key, e_id in context.delta_ids.items() if key[0] == EventTypes.Member ] for etype, state_key in context.delta_ids: @@ -591,7 +568,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): missing_member_event_ids.append(event_id) if missing_member_event_ids: - event_to_memberships = yield self._get_joined_profiles_from_event_ids( + event_to_memberships = await self._get_joined_profiles_from_event_ids( missing_member_event_ids ) users_in_room.update((row for row in event_to_memberships.values() if row)) @@ -611,23 +588,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", - list_name="event_ids", - inlineCallbacks=True, + cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids", ) - def _get_joined_profiles_from_event_ids(self, event_ids): + async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. Args: - event_ids (Iterable[str]): The member event IDs to lookup + event_ids: The member event IDs to lookup Returns: - Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -647,8 +622,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): for row in rows } - @cachedInlineCallbacks(max_entries=10000) - def is_host_joined(self, room_id, host): + @cached(max_entries=10000) + async def is_host_joined(self, room_id: str, host: str) -> bool: if "%" in host or "_" in host: raise Exception("Invalid host name") @@ -667,47 +642,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @cachedInlineCallbacks() - def was_host_joined(self, room_id, host): - """Check whether the server is or ever was in the room. - - Args: - room_id (str) - host (str) - - Returns: - Deferred: Resolves to True if the host is/was in the room, otherwise - False. - """ - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT user_id FROM room_memberships - WHERE room_id = ? - AND user_id LIKE ? - AND membership = 'join' - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) + rows = await self.db_pool.execute( + "is_host_joined", None, sql, room_id, like_clause + ) if not rows: return False @@ -719,8 +656,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True - @defer.inlineCallbacks - def get_joined_hosts(self, room_id, state_entry): + async def get_joined_hosts(self, room_id: str, state_entry): state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -730,32 +666,28 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() with Measure(self._clock, "get_joined_hosts"): - return ( - yield self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) + return await self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry ) - @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - # @defer.inlineCallbacks - def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + @cached(num_args=2, max_entries=10000, iterable=True) + async def _get_joined_hosts( + self, room_id, state_group, current_state_ids, state_entry + ): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - cache = yield self._get_joined_hosts_cache(room_id) - joined_hosts = yield cache.get_destinations(state_entry) - - return joined_hosts + cache = await self._get_joined_hosts_cache(room_id) + return await cache.get_destinations(state_entry) @cached(max_entries=10000) - def _get_joined_hosts_cache(self, room_id): + def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache(self, room_id) - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): + @cached(num_args=2) + async def did_forget(self, user_id: str, room_id: str) -> bool: """Returns whether user_id has elected to discard history for room_id. Returns False if they have since re-joined.""" @@ -777,18 +709,18 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = txn.fetchall() return rows[0][0] - count = yield self.db.runInteraction("did_forget_membership", f) + count = await self.db_pool.runInteraction("did_forget_membership", f) return count == 0 @cached() - def get_forgotten_rooms_for_user(self, user_id): + async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: """Gets all rooms the user has forgotten. Args: - user_id (str) + user_id: The user ID to query the rooms of. Returns: - Deferred[set[str]] + The forgotten rooms. """ def _get_forgotten_rooms_for_user_txn(txn): @@ -814,22 +746,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (user_id,)) return {row[0] for row in txn if row[1] == 0} - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) - @defer.inlineCallbacks - def get_rooms_user_has_been_in(self, user_id): + async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]: """Get all rooms that the user has ever been in. Args: - user_id (str) + user_id: The user ID to get the rooms of. Returns: - Deferred[set[str]]: Set of room IDs. + Set of room IDs. """ - room_ids = yield self.db.simple_select_onecol( + room_ids = await self.db_pool.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -838,13 +769,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) - def get_membership_from_event_ids( + async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: """Get user_id and membership of a set of event IDs. """ - return self.db.simple_select_many_batch( + return await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -880,23 +811,23 @@ class RoomMemberWorkerStore(EventsWorkerStore): return bool(txn.fetchone()) - return await self.db.runInteraction( + return await self.db_pool.runInteraction( "is_local_host_in_room_ignoring_users", _is_local_host_in_room_ignoring_users_txn, ) class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, self._background_current_state_membership, ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( "room_membership_forgotten_idx", index_name="room_memberships_user_room_forgotten", table="room_memberships", @@ -904,8 +835,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): where_clause="forgotten = 1", ) - @defer.inlineCallbacks - def _background_add_membership_profile(self, progress, batch_size): + async def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( "target_min_stream_id_inclusive", self._min_stream_order_on_start ) @@ -929,7 +859,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return 0 @@ -940,7 +870,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): event_id = row["event_id"] room_id = row["room_id"] try: - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) content = event_json["content"] except Exception: continue @@ -964,25 +894,24 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): "max_stream_id_exclusive": min_stream_id, } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress ) return len(rows) - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn ) if not result: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( _MEMBERSHIP_PROFILE_UPDATE_NAME ) return result - @defer.inlineCallbacks - def _background_current_state_membership(self, progress, batch_size): + async def _background_current_state_membership(self, progress, batch_size): """Update the new membership column on current_state_events. This works by iterating over all rooms in alphebetical order. @@ -1016,7 +945,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): last_processed_room = next_room - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, {"last_processed_room": last_processed_room}, @@ -1028,14 +957,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): # string, which will compare before all room IDs correctly. last_processed_room = progress.get("last_processed_room", "") - row_count, finished = yield self.db.runInteraction( + row_count, finished = await self.db_pool.runInteraction( "_background_current_state_membership_update", _background_current_state_membership_txn, last_processed_room, ) if finished: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME ) @@ -1043,10 +972,10 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def forget(self, user_id, room_id): + async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -1067,10 +996,10 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.db.runInteraction("forget_membership", f) + await self.db_pool.runInteraction("forget_membership", f) -class _JoinedHostsCache(object): +class _JoinedHostsCache: """Cache for joined hosts in a room that is optimised to handle updates via state deltas. """ @@ -1087,21 +1016,23 @@ class _JoinedHostsCache(object): self._len = 0 - @defer.inlineCallbacks - def get_destinations(self, state_entry): + async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]: """Get set of destinations for a state entry Args: - state_entry(synapse.state._StateCacheEntry) + state_entry + + Returns: + The destinations as a set. """ if state_entry.state_group == self.state_group: return frozenset(self.hosts_to_joined_users) - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_entry.state_group == self.state_group: pass elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in iteritems(state_entry.delta_ids): + for (typ, state_key), event_id in state_entry.delta_ids.items(): if typ != EventTypes.Member: continue @@ -1109,7 +1040,7 @@ class _JoinedHostsCache(object): user_id = state_key known_joins = self.hosts_to_joined_users.setdefault(host, set()) - event = yield self.store.get_event(event_id) + event = await self.store.get_event(event_id) if event.membership == Membership.JOIN: known_joins.add(user_id) else: @@ -1118,7 +1049,7 @@ class _JoinedHostsCache(object): if not known_joins: self.hosts_to_joined_users.pop(host, None) else: - joined_users = yield self.store.get_joined_users_from_state( + joined_users = await self.store.get_joined_users_from_state( self.room_id, state_entry ) @@ -1131,7 +1062,7 @@ class _JoinedHostsCache(object): self.state_group = state_entry.state_group else: self.state_group = object() - self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) + self._len = sum(len(v) for v in self.hosts_to_joined_users.values()) return frozenset(self.hosts_to_joined_users) def __len__(self): diff --git a/synapse/storage/data_stores/main/schema/delta/12/v12.sql b/synapse/storage/databases/main/schema/delta/12/v12.sql index 5964c5aaac..5964c5aaac 100644 --- a/synapse/storage/data_stores/main/schema/delta/12/v12.sql +++ b/synapse/storage/databases/main/schema/delta/12/v12.sql diff --git a/synapse/storage/data_stores/main/schema/delta/13/v13.sql b/synapse/storage/databases/main/schema/delta/13/v13.sql index f8649e5d99..f8649e5d99 100644 --- a/synapse/storage/data_stores/main/schema/delta/13/v13.sql +++ b/synapse/storage/databases/main/schema/delta/13/v13.sql diff --git a/synapse/storage/data_stores/main/schema/delta/14/v14.sql b/synapse/storage/databases/main/schema/delta/14/v14.sql index a831920da6..a831920da6 100644 --- a/synapse/storage/data_stores/main/schema/delta/14/v14.sql +++ b/synapse/storage/databases/main/schema/delta/14/v14.sql diff --git a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql b/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql index e4f5e76aec..e4f5e76aec 100644 --- a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql +++ b/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql diff --git a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql b/synapse/storage/databases/main/schema/delta/15/presence_indices.sql index 6b8d0f1ca7..6b8d0f1ca7 100644 --- a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql +++ b/synapse/storage/databases/main/schema/delta/15/presence_indices.sql diff --git a/synapse/storage/data_stores/main/schema/delta/15/v15.sql b/synapse/storage/databases/main/schema/delta/15/v15.sql index 9523d2bcc3..9523d2bcc3 100644 --- a/synapse/storage/data_stores/main/schema/delta/15/v15.sql +++ b/synapse/storage/databases/main/schema/delta/15/v15.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql b/synapse/storage/databases/main/schema/delta/16/events_order_index.sql index a48f215170..a48f215170 100644 --- a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql +++ b/synapse/storage/databases/main/schema/delta/16/events_order_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql index 7a15265cb1..7a15265cb1 100644 --- a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql +++ b/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql b/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql index 65c97b5e2f..65c97b5e2f 100644 --- a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql +++ b/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql b/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql index f82486132b..f82486132b 100644 --- a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql +++ b/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql b/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql index 5b8de52c33..5b8de52c33 100644 --- a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql +++ b/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql diff --git a/synapse/storage/data_stores/main/schema/delta/16/users.sql b/synapse/storage/databases/main/schema/delta/16/users.sql index cd0709250d..cd0709250d 100644 --- a/synapse/storage/data_stores/main/schema/delta/16/users.sql +++ b/synapse/storage/databases/main/schema/delta/16/users.sql diff --git a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql b/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql index 7c9a90e27f..7c9a90e27f 100644 --- a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql +++ b/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql diff --git a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql b/synapse/storage/databases/main/schema/delta/17/server_keys.sql index 70b247a06b..70b247a06b 100644 --- a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql +++ b/synapse/storage/databases/main/schema/delta/17/server_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql b/synapse/storage/databases/main/schema/delta/17/user_threepids.sql index c17715ac80..c17715ac80 100644 --- a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql +++ b/synapse/storage/databases/main/schema/delta/17/user_threepids.sql diff --git a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql index 6e0871c92b..6e0871c92b 100644 --- a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql +++ b/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql diff --git a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql b/synapse/storage/databases/main/schema/delta/19/event_index.sql index 18b97b4332..18b97b4332 100644 --- a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql +++ b/synapse/storage/databases/main/schema/delta/19/event_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql b/synapse/storage/databases/main/schema/delta/20/dummy.sql index e0ac49d1ec..e0ac49d1ec 100644 --- a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql +++ b/synapse/storage/databases/main/schema/delta/20/dummy.sql diff --git a/synapse/storage/data_stores/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py index 3edfcfd783..3edfcfd783 100644 --- a/synapse/storage/data_stores/main/schema/delta/20/pushers.py +++ b/synapse/storage/databases/main/schema/delta/20/pushers.py diff --git a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql b/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql index 4c2fb20b77..4c2fb20b77 100644 --- a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql +++ b/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql b/synapse/storage/databases/main/schema/delta/21/receipts.sql index d070845477..d070845477 100644 --- a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql +++ b/synapse/storage/databases/main/schema/delta/21/receipts.sql diff --git a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql b/synapse/storage/databases/main/schema/delta/22/receipts_index.sql index bfc0b3bcaa..bfc0b3bcaa 100644 --- a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql +++ b/synapse/storage/databases/main/schema/delta/22/receipts_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql b/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql index 87edfa454c..87edfa454c 100644 --- a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql +++ b/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql diff --git a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql b/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql index acea7483bd..acea7483bd 100644 --- a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql +++ b/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql diff --git a/synapse/storage/data_stores/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py index 4b2ffd35fd..ee675e71ff 100644 --- a/synapse/storage/data_stores/main/schema/delta/25/fts.py +++ b/synapse/storage/databases/main/schema/delta/25/fts.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import logging -import simplejson - from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.prepare_database import get_statements @@ -66,7 +64,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = simplejson.dumps(progress) + progress_json = json.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql b/synapse/storage/databases/main/schema/delta/25/guest_access.sql index 1ea389b471..1ea389b471 100644 --- a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql +++ b/synapse/storage/databases/main/schema/delta/25/guest_access.sql diff --git a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql b/synapse/storage/databases/main/schema/delta/25/history_visibility.sql index f468fc1897..f468fc1897 100644 --- a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql +++ b/synapse/storage/databases/main/schema/delta/25/history_visibility.sql diff --git a/synapse/storage/data_stores/main/schema/delta/25/tags.sql b/synapse/storage/databases/main/schema/delta/25/tags.sql index 7a32ce68e4..7a32ce68e4 100644 --- a/synapse/storage/data_stores/main/schema/delta/25/tags.sql +++ b/synapse/storage/databases/main/schema/delta/25/tags.sql diff --git a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql b/synapse/storage/databases/main/schema/delta/26/account_data.sql index e395de2b5e..e395de2b5e 100644 --- a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql +++ b/synapse/storage/databases/main/schema/delta/26/account_data.sql diff --git a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql b/synapse/storage/databases/main/schema/delta/27/account_data.sql index bf0558b5b3..bf0558b5b3 100644 --- a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql +++ b/synapse/storage/databases/main/schema/delta/27/account_data.sql diff --git a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql b/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql index e2094f37fe..e2094f37fe 100644 --- a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql +++ b/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql diff --git a/synapse/storage/data_stores/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py index 414f9f5aa0..b7972cfa8e 100644 --- a/synapse/storage/data_stores/main/schema/delta/27/ts.py +++ b/synapse/storage/databases/main/schema/delta/27/ts.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import logging -import simplejson - from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = simplejson.dumps(progress) + progress_json = json.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql b/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql index 4d519849df..4d519849df 100644 --- a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql +++ b/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql b/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql index 36609475f1..36609475f1 100644 --- a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql +++ b/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql b/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql index 6c1fd68c5b..6c1fd68c5b 100644 --- a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql +++ b/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql index cb84c69baa..cb84c69baa 100644 --- a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql +++ b/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql b/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql index 3e4a9ab455..3e4a9ab455 100644 --- a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql +++ b/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql diff --git a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql b/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql index 21d2b420bf..21d2b420bf 100644 --- a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql +++ b/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql diff --git a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql b/synapse/storage/databases/main/schema/delta/29/push_actions.sql index 84b21cf813..84b21cf813 100644 --- a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql +++ b/synapse/storage/databases/main/schema/delta/29/push_actions.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql b/synapse/storage/databases/main/schema/delta/30/alias_creator.sql index c9d0dde638..c9d0dde638 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql +++ b/synapse/storage/databases/main/schema/delta/30/alias_creator.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py index 9b95411fb6..b42c02710a 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/as_users.py +++ b/synapse/storage/databases/main/schema/delta/30/as_users.py @@ -13,8 +13,6 @@ # limitations under the License. import logging -from six.moves import range - from synapse.config.appservice import load_appservices logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql index 712c454aa1..712c454aa1 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql +++ b/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql b/synapse/storage/databases/main/schema/delta/30/presence_stream.sql index 606bbb037d..606bbb037d 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql +++ b/synapse/storage/databases/main/schema/delta/30/presence_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql b/synapse/storage/databases/main/schema/delta/30/public_rooms.sql index f09db4faa6..f09db4faa6 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql +++ b/synapse/storage/databases/main/schema/delta/30/public_rooms.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql b/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql index 735aa8d5f6..735aa8d5f6 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql +++ b/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql index 0dd2f1360c..0dd2f1360c 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql +++ b/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql diff --git a/synapse/storage/data_stores/main/schema/delta/31/invites.sql b/synapse/storage/databases/main/schema/delta/31/invites.sql index 2c57846d5a..2c57846d5a 100644 --- a/synapse/storage/data_stores/main/schema/delta/31/invites.sql +++ b/synapse/storage/databases/main/schema/delta/31/invites.sql diff --git a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql index 9efb4280eb..9efb4280eb 100644 --- a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql +++ b/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py index 9bb504aad5..9bb504aad5 100644 --- a/synapse/storage/data_stores/main/schema/delta/31/pushers.py +++ b/synapse/storage/databases/main/schema/delta/31/pushers.py diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql b/synapse/storage/databases/main/schema/delta/31/pushers_index.sql index a82add88fd..a82add88fd 100644 --- a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql +++ b/synapse/storage/databases/main/schema/delta/31/pushers_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py index 7d8ca5f93f..63b757ade6 100644 --- a/synapse/storage/data_stores/main/schema/delta/31/search_update.py +++ b/synapse/storage/databases/main/schema/delta/31/search_update.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import logging -import simplejson - from synapse.storage.engines import PostgresEngine from synapse.storage.prepare_database import get_statements @@ -50,7 +48,7 @@ def run_create(cur, database_engine, *args, **kwargs): "rows_inserted": 0, "have_added_indexes": False, } - progress_json = simplejson.dumps(progress) + progress_json = json.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/data_stores/main/schema/delta/32/events.sql b/synapse/storage/databases/main/schema/delta/32/events.sql index 1dd0f9e170..1dd0f9e170 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/events.sql +++ b/synapse/storage/databases/main/schema/delta/32/events.sql diff --git a/synapse/storage/data_stores/main/schema/delta/32/openid.sql b/synapse/storage/databases/main/schema/delta/32/openid.sql index 36f37b11c8..36f37b11c8 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/openid.sql +++ b/synapse/storage/databases/main/schema/delta/32/openid.sql diff --git a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql b/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql index d86d30c13c..d86d30c13c 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql +++ b/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/databases/main/schema/delta/32/remove_indices.sql index 2de50d408c..2de50d408c 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql +++ b/synapse/storage/databases/main/schema/delta/32/remove_indices.sql diff --git a/synapse/storage/data_stores/main/schema/delta/32/reports.sql b/synapse/storage/databases/main/schema/delta/32/reports.sql index d13609776f..d13609776f 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/reports.sql +++ b/synapse/storage/databases/main/schema/delta/32/reports.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql index 61ad3fe3e8..61ad3fe3e8 100644 --- a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql +++ b/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices.sql b/synapse/storage/databases/main/schema/delta/33/devices.sql index eca7268d82..eca7268d82 100644 --- a/synapse/storage/data_stores/main/schema/delta/33/devices.sql +++ b/synapse/storage/databases/main/schema/delta/33/devices.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql index aa4a3b9f2f..aa4a3b9f2f 100644 --- a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql +++ b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql index 6671573398..6671573398 100644 --- a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql +++ b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql diff --git a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py index bff1256a7b..a3e81eeac7 100644 --- a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py +++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import logging -import simplejson - from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = simplejson.dumps(progress) + progress_json = json.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py index a26057dfb6..a26057dfb6 100644 --- a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py diff --git a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql index 473f75a78e..473f75a78e 100644 --- a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql +++ b/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql b/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql index 69e16eda0f..69e16eda0f 100644 --- a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql +++ b/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py b/synapse/storage/databases/main/schema/delta/34/cache_stream.py index cf09e43e2b..cf09e43e2b 100644 --- a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py +++ b/synapse/storage/databases/main/schema/delta/34/cache_stream.py diff --git a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql b/synapse/storage/databases/main/schema/delta/34/device_inbox.sql index e68844c74a..e68844c74a 100644 --- a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql +++ b/synapse/storage/databases/main/schema/delta/34/device_inbox.sql diff --git a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql b/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql index 0d9fe1a99a..0d9fe1a99a 100644 --- a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql +++ b/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql diff --git a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py b/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py index 67d505e68b..67d505e68b 100644 --- a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py +++ b/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py diff --git a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql b/synapse/storage/databases/main/schema/delta/35/contains_url.sql index 6cd123027b..6cd123027b 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql +++ b/synapse/storage/databases/main/schema/delta/35/contains_url.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql b/synapse/storage/databases/main/schema/delta/35/device_outbox.sql index 17e6c43105..17e6c43105 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql +++ b/synapse/storage/databases/main/schema/delta/35/device_outbox.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql b/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql index 7ab7d942e2..7ab7d942e2 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql +++ b/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql b/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql index 2e836d8e9c..2e836d8e9c 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql +++ b/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql index dd2bf2e28a..dd2bf2e28a 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql +++ b/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql index 2b945d8a57..2b945d8a57 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql +++ b/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql diff --git a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql b/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql index 90d8fd18f9..90d8fd18f9 100644 --- a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql +++ b/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql diff --git a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py b/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py index a377884169..a377884169 100644 --- a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py +++ b/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py diff --git a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql b/synapse/storage/databases/main/schema/delta/37/user_threepids.sql index cf7a90dd10..cf7a90dd10 100644 --- a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql +++ b/synapse/storage/databases/main/schema/delta/37/user_threepids.sql diff --git a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql index 515e6b8e84..515e6b8e84 100644 --- a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql +++ b/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql b/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql index 74bdc49073..74bdc49073 100644 --- a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql +++ b/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql index 00be801e90..00be801e90 100644 --- a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql +++ b/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql b/synapse/storage/databases/main/schema/delta/39/event_push_index.sql index de2ad93e5c..de2ad93e5c 100644 --- a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql +++ b/synapse/storage/databases/main/schema/delta/39/event_push_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql b/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql index 5af814290b..5af814290b 100644 --- a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql +++ b/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql diff --git a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql b/synapse/storage/databases/main/schema/delta/39/membership_profile.sql index 1bf911c8ab..1bf911c8ab 100644 --- a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql +++ b/synapse/storage/databases/main/schema/delta/39/membership_profile.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql b/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql index 7ffa189f39..7ffa189f39 100644 --- a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql +++ b/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql b/synapse/storage/databases/main/schema/delta/40/device_inbox.sql index b9fe1f0480..b9fe1f0480 100644 --- a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql +++ b/synapse/storage/databases/main/schema/delta/40/device_inbox.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql b/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql index dd6dcb65f1..dd6dcb65f1 100644 --- a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql b/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql index 3918f0b794..3918f0b794 100644 --- a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql +++ b/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql diff --git a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql b/synapse/storage/databases/main/schema/delta/40/pushers.sql index 054a223f14..054a223f14 100644 --- a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql +++ b/synapse/storage/databases/main/schema/delta/40/pushers.sql diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql index b7bee8b692..b7bee8b692 100644 --- a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql +++ b/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql b/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql index 62f0b9892b..62f0b9892b 100644 --- a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql +++ b/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql index 5d9cfecf36..5d9cfecf36 100644 --- a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql +++ b/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql b/synapse/storage/databases/main/schema/delta/41/ratelimit.sql index a194bf0238..a194bf0238 100644 --- a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql +++ b/synapse/storage/databases/main/schema/delta/41/ratelimit.sql diff --git a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql b/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql index d28851aff8..d28851aff8 100644 --- a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql +++ b/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql diff --git a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql b/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql index 9ab8c14fa3..9ab8c14fa3 100644 --- a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql +++ b/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql diff --git a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql b/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql index b8821ac759..b8821ac759 100644 --- a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql +++ b/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql diff --git a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py b/synapse/storage/databases/main/schema/delta/42/user_dir.py index 506f326f4d..506f326f4d 100644 --- a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py +++ b/synapse/storage/databases/main/schema/delta/42/user_dir.py diff --git a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql b/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql index 0e3cd143ff..0e3cd143ff 100644 --- a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql +++ b/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql diff --git a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql b/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql index 630907ec4f..630907ec4f 100644 --- a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql +++ b/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql diff --git a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql b/synapse/storage/databases/main/schema/delta/43/url_cache.sql index 45ebe020da..45ebe020da 100644 --- a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql +++ b/synapse/storage/databases/main/schema/delta/43/url_cache.sql diff --git a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql b/synapse/storage/databases/main/schema/delta/43/user_share.sql index ee7062abe4..ee7062abe4 100644 --- a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql +++ b/synapse/storage/databases/main/schema/delta/43/user_share.sql diff --git a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql b/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql index b12f9b2ebf..b12f9b2ebf 100644 --- a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql +++ b/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql diff --git a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql b/synapse/storage/databases/main/schema/delta/45/group_server.sql index b2333848a0..b2333848a0 100644 --- a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql +++ b/synapse/storage/databases/main/schema/delta/45/group_server.sql diff --git a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql b/synapse/storage/databases/main/schema/delta/45/profile_cache.sql index e5ddc84df0..e5ddc84df0 100644 --- a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql +++ b/synapse/storage/databases/main/schema/delta/45/profile_cache.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql index 68c48a89a9..68c48a89a9 100644 --- a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql +++ b/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql index bb307889c1..bb307889c1 100644 --- a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql +++ b/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql b/synapse/storage/databases/main/schema/delta/46/group_server.sql index 097679bc9a..097679bc9a 100644 --- a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql +++ b/synapse/storage/databases/main/schema/delta/46/group_server.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql index bbfc7f5d1a..bbfc7f5d1a 100644 --- a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql +++ b/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql index cb0d5a2576..cb0d5a2576 100644 --- a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql +++ b/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql index d9505f8da1..d9505f8da1 100644 --- a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql +++ b/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql diff --git a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql b/synapse/storage/databases/main/schema/delta/47/last_access_media.sql index f505fb22b5..f505fb22b5 100644 --- a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql +++ b/synapse/storage/databases/main/schema/delta/47/last_access_media.sql diff --git a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql index 31d7a817eb..31d7a817eb 100644 --- a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql +++ b/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql diff --git a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql b/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql index edccf4a96f..edccf4a96f 100644 --- a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql +++ b/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql b/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql index 5237491506..5237491506 100644 --- a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql +++ b/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql index 9248b0b24a..9248b0b24a 100644 --- a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql +++ b/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql b/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql index e9013a6969..e9013a6969 100644 --- a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql +++ b/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql diff --git a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py b/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py index 49f5f2c003..49f5f2c003 100644 --- a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py +++ b/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py diff --git a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql b/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql index ce26eaf0c9..ce26eaf0c9 100644 --- a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql +++ b/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql index 14dcf18d73..14dcf18d73 100644 --- a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql +++ b/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql index 3dd478196f..3dd478196f 100644 --- a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql +++ b/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql index 3a4ed59b5b..3a4ed59b5b 100644 --- a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql +++ b/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql index c93ae47532..c93ae47532 100644 --- a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql +++ b/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql b/synapse/storage/databases/main/schema/delta/50/erasure_store.sql index 5d8641a9ab..5d8641a9ab 100644 --- a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql +++ b/synapse/storage/databases/main/schema/delta/50/erasure_store.sql diff --git a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py b/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py index b1684a8441..b1684a8441 100644 --- a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py +++ b/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py diff --git a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql index c0e66a697d..c0e66a697d 100644 --- a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql +++ b/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql b/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql index c9d537d5a3..c9d537d5a3 100644 --- a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql +++ b/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql diff --git a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql index 91e03d13e1..91e03d13e1 100644 --- a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql +++ b/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql index bfa49e6f92..bfa49e6f92 100644 --- a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql +++ b/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql index db687cccae..db687cccae 100644 --- a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql +++ b/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql index 88ec2f83e5..88ec2f83e5 100644 --- a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql +++ b/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql index e372f5a44a..e372f5a44a 100644 --- a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql +++ b/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql b/synapse/storage/databases/main/schema/delta/53/event_format_version.sql index 1d977c2834..1d977c2834 100644 --- a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql +++ b/synapse/storage/databases/main/schema/delta/53/event_format_version.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql b/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql index ffcc896b58..ffcc896b58 100644 --- a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql +++ b/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql index b812c5794f..b812c5794f 100644 --- a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql +++ b/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql b/synapse/storage/databases/main/schema/delta/53/user_share.sql index 5831b1a6f8..5831b1a6f8 100644 --- a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql +++ b/synapse/storage/databases/main/schema/delta/53/user_share.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql b/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql index 80c2c573b6..80c2c573b6 100644 --- a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql +++ b/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql diff --git a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql index f7827ca6d2..f7827ca6d2 100644 --- a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql +++ b/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql index 0adb2ad55e..0adb2ad55e 100644 --- a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql +++ b/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql index c01aa9d2d9..c01aa9d2d9 100644 --- a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql +++ b/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql index b062ec840c..b062ec840c 100644 --- a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql +++ b/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql index dbbe682697..dbbe682697 100644 --- a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql +++ b/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql b/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql index e6ee70c623..e6ee70c623 100644 --- a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql +++ b/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/relations.sql b/synapse/storage/databases/main/schema/delta/54/relations.sql index 134862b870..134862b870 100644 --- a/synapse/storage/data_stores/main/schema/delta/54/relations.sql +++ b/synapse/storage/databases/main/schema/delta/54/relations.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats.sql b/synapse/storage/databases/main/schema/delta/54/stats.sql index 652e58308e..652e58308e 100644 --- a/synapse/storage/data_stores/main/schema/delta/54/stats.sql +++ b/synapse/storage/databases/main/schema/delta/54/stats.sql diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql b/synapse/storage/databases/main/schema/delta/54/stats2.sql index 3b2d48447f..3b2d48447f 100644 --- a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql +++ b/synapse/storage/databases/main/schema/delta/54/stats2.sql diff --git a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql b/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql index 4590604bfd..4590604bfd 100644 --- a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql +++ b/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql diff --git a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql b/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql index a8eced2e0a..a8eced2e0a 100644 --- a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql +++ b/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql diff --git a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql index dabdde489b..dabdde489b 100644 --- a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql +++ b/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql index 41807eb1e7..41807eb1e7 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql +++ b/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql index 473018676f..473018676f 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql +++ b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql index 3133d42d4a..3133d42d4a 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql +++ b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql index 1d2ddb1b1a..1d2ddb1b1a 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql +++ b/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql index f00889290b..f00889290b 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql +++ b/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres index b9bbb18a91..b9bbb18a91 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql index c2f557fde9..c2f557fde9 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql +++ b/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql index dfa902d0ba..dfa902d0ba 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql +++ b/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql index 9f09922c67..9f09922c67 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql +++ b/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/databases/main/schema/delta/56/event_expiry.sql index 81a36a8b1d..81a36a8b1d 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql +++ b/synapse/storage/databases/main/schema/delta/56/event_expiry.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql index 5e29c1da19..5e29c1da19 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql index 5f5e0499ae..5f5e0499ae 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql +++ b/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql index 014cb3b538..014cb3b538 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql +++ b/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql b/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql index 67f8b20297..67f8b20297 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql +++ b/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite index e8b1fd35d8..e8b1fd35d8 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite +++ b/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql index 4f24c1405d..4f24c1405d 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql +++ b/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql index 7be31ffebb..7be31ffebb 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql +++ b/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql index ea95db0ed7..ea95db0ed7 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql +++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql index 49ce35d794..49ce35d794 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql +++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres index 67471f3ef5..67471f3ef5 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql index b7550f6f4e..b7550f6f4e 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql +++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql index aeb17813d3..aeb17813d3 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql +++ b/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql index 7d70dd071e..7d70dd071e 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql +++ b/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql b/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql index 92ab1f5e65..92ab1f5e65 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql +++ b/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/databases/main/schema/delta/56/room_retention.sql index ee6cdf7a14..ee6cdf7a14 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql +++ b/synapse/storage/databases/main/schema/delta/56/room_retention.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys.sql index 5c5fffcafb..5c5fffcafb 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql +++ b/synapse/storage/databases/main/schema/delta/56/signing_keys.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql index 0aa90ebf0c..0aa90ebf0c 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql +++ b/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/databases/main/schema/delta/56/stats_separated.sql index bbdde121e8..bbdde121e8 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql +++ b/synapse/storage/databases/main/schema/delta/56/stats_separated.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py index 1de8b54961..1de8b54961 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py +++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py diff --git a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql index 91390c4527..91390c4527 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql +++ b/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql index 149f8be8b6..149f8be8b6 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql +++ b/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql index aec06c8261..aec06c8261 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql +++ b/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql index c3b6de2099..c3b6de2099 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql +++ b/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py index 63b5acdcf7..63b5acdcf7 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py +++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql index 133d80af35..133d80af35 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql +++ b/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql index 352a66f5b0..352a66f5b0 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres index c601cff6de..c601cff6de 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite index 335c6f2074..335c6f2074 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres index 92aaadde0d..92aaadde0d 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite index e19dab97cb..e19dab97cb 100644 --- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite +++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql index fdc39e9ba5..fdc39e9ba5 100644 --- a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql +++ b/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql index dcb593fc2d..dcb593fc2d 100644 --- a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql +++ b/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres index aa46eb0e10..aa46eb0e10 100644 --- a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py index d353f2bcb3..d353f2bcb3 100644 --- a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py +++ b/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql new file mode 100644 index 0000000000..4cc96a5341 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A table of the IP address and user-agent used to complete each step of a +-- user-interactive authentication session. +CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres new file mode 100644 index 0000000000..597f2ffd3d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The local_media_repository should have files which do not get quarantined, +-- e.g. files from sticker packs. +ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite new file mode 100644 index 0000000000..69db89ac0e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The local_media_repository should have files which do not get quarantined, +-- e.g. files from sticker packs. +ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0; diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql new file mode 100644 index 0000000000..260b009b48 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A shadow-banned user may be told that their requests succeeded when they were +-- actually ignored. +ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN; diff --git a/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql b/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql new file mode 100644 index 0000000000..eb57203e46 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +The version of synapse 1.16.0 on pypi incorrectly contained a migration which +added a table called 'local_rejections_stream'. This table is not used, and +we drop it here for anyone who was affected. +*/ + +DROP TABLE IF EXISTS local_rejections_stream; diff --git a/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql new file mode 100644 index 0000000000..1cc2633aad --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We need to store the stream positions by instance in a sharded config world. +-- +-- We default to master as we want the column to be NOT NULL and we correctly +-- reset the instance name to match the config each time we start up. +ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master'; + +CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name); diff --git a/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py new file mode 100644 index 0000000000..4310ec12ce --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py @@ -0,0 +1,34 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adds a postgres SEQUENCE for generating guest user IDs. +""" + +from synapse.storage.databases.main.registration import ( + find_max_generated_user_id_localpart, +) +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if not isinstance(database_engine, PostgresEngine): + return + + next_id = find_max_generated_user_id_localpart(cur) + 1 + cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql new file mode 100644 index 0000000000..cade5dcca8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql @@ -0,0 +1,32 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Recalculate the stats for all rooms after the fix to joined_members erroneously +-- incrementing on per-room profile changes. + +-- Note that the populate_stats_process_rooms background update is already set to +-- run if you're upgrading from Synapse <1.0.0. + +-- Additionally, if you've upgraded to v1.18.0 (which doesn't include this fix), +-- this bg job runs, and then update to v1.19.0, you'd end up with only half of +-- your rooms having room stats recalculated after this fix was in place. + +-- So we've switched the old `populate_stats_process_rooms` background job to a +-- no-op, and then kick off a bg job with a new name, but with the same +-- functionality as the old one. This effectively restarts the background job +-- from the beginning, without running it twice in a row, supporting both +-- upgrade usecases. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_stats_process_rooms_2', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql new file mode 100644 index 0000000000..15421b99ac --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This table is no longer used. +DROP TABLE IF EXISTS presence_allow_inbound; diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql new file mode 100644 index 0000000000..317fba8a5d --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql @@ -0,0 +1,26 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We're hijacking the push actions to store unread messages and unread counts (specified +-- in MSC2654) because doing otherwise would result in either performance issues or +-- reimplementing a consequent bit of the push actions. + +-- Add columns to event_push_actions and event_push_actions_staging to track unread +-- messages and calculate unread counts. +ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT; +ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT; + +-- Add column to event_push_summary +ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT; \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql b/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql index 883fcd10b2..883fcd10b2 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql index 10ce2aa7a0..10ce2aa7a0 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql index 95826da431..95826da431 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql b/synapse/storage/databases/main/schema/full_schemas/16/im.sql index a1a2aa8e5b..a1a2aa8e5b 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/im.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql b/synapse/storage/databases/main/schema/full_schemas/16/keys.sql index 11cdffdbb3..11cdffdbb3 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/keys.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql b/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql index 8f3759bb2a..8f3759bb2a 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql b/synapse/storage/databases/main/schema/full_schemas/16/presence.sql index 01d2d8f833..01d2d8f833 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/presence.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql b/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql index c04f4747d9..c04f4747d9 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql b/synapse/storage/databases/main/schema/full_schemas/16/push.sql index e44465cf45..e44465cf45 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/push.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql index 318f0d9aa5..318f0d9aa5 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql b/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql index d47da3b12f..d47da3b12f 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql b/synapse/storage/databases/main/schema/full_schemas/16/state.sql index 96391a8f0e..96391a8f0e 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/state.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql index 17e67bedac..17e67bedac 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql b/synapse/storage/databases/main/schema/full_schemas/16/users.sql index f013aa8b18..f013aa8b18 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql +++ b/synapse/storage/databases/main/schema/full_schemas/16/users.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres index 889a9a0ce4..889a9a0ce4 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite index a0411ede7e..a0411ede7e 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql index 91d21b2921..91d21b2921 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql +++ b/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/databases/main/schema/full_schemas/README.md index c00f287190..c00f287190 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/README.md +++ b/synapse/storage/databases/main/schema/full_schemas/README.md diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/databases/main/search.py index 13f49d8060..f01cf2fd02 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -16,17 +16,13 @@ import logging import re from collections import namedtuple - -from six import string_types - -from canonicaljson import json - -from twisted.internet import defer +from typing import List, Optional, Set from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour -from synapse.storage.database import Database +from synapse.events import EventBase +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -92,16 +88,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) if not hs.config.enable_search: return - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) @@ -110,16 +106,15 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # a GIN index. However, it's possible that some people might still have # the background update queued, so we register a handler to clear the # background update. - self.db.updates.register_noop_background_update( + self.db_pool.updates.register_noop_background_update( self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) - @defer.inlineCallbacks - def _background_reindex_search(self, progress, batch_size): + async def _background_reindex_search(self, progress, batch_size): # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -144,7 +139,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # store_search_entries_txn with a generator function, but that # would mean having two cursors open on the database at once. # Instead we just build a list of results. - rows = self.db.cursor_to_dict(txn) + rows = self.db_pool.cursor_to_dict(txn) if not rows: return 0 @@ -159,7 +154,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): stream_ordering = row["stream_ordering"] origin_server_ts = row["origin_server_ts"] try: - event_json = json.loads(row["json"]) + event_json = db_to_json(row["json"]) content = event_json["content"] except Exception: continue @@ -180,7 +175,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # skip over it. continue - if not isinstance(value, string_types): + if not isinstance(value, str): # If the event body, name or topic isn't a string # then skip over it continue @@ -204,23 +199,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): "rows_inserted": rows_inserted + len(event_search_rows), } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_UPDATE_NAME, progress ) return len(event_search_rows) - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) if not result: - yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + await self.db_pool.updates._end_background_update( + self.EVENT_SEARCH_UPDATE_NAME + ) return result - @defer.inlineCallbacks - def _background_reindex_gin_search(self, progress, batch_size): + async def _background_reindex_gin_search(self, progress, batch_size): """This handles old synapses which used GIST indexes, if any; converting them back to be GIN as per the actual schema. """ @@ -257,15 +253,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): - yield self.db.runWithConnection(create_index) + await self.db_pool.runWithConnection(create_index) - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME ) return 1 - @defer.inlineCallbacks - def _background_reindex_search_order(self, progress, batch_size): + async def _background_reindex_search_order(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -290,14 +285,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ) conn.set_session(autocommit=False) - yield self.db.runWithConnection(create_index) + await self.db_pool.runWithConnection(create_index) pg = dict(progress) pg["have_added_indexes"] = True - yield self.db.runInteraction( + await self.db_pool.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self.db.updates._background_update_progress_txn, + self.db_pool.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, ) @@ -327,18 +322,18 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): "have_added_indexes": True, } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress ) return len(rows), True - num_rows, finished = yield self.db.runInteraction( + num_rows, finished = await self.db_pool.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn ) if not finished: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_ORDER_UPDATE_NAME ) @@ -346,11 +341,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SearchStore, self).__init__(database, db_conn, hs) - @defer.inlineCallbacks - def search_msgs(self, room_ids, search_term, keys): + async def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. Args: @@ -427,15 +421,15 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self.db.execute( - "search_msgs", self.db.cursor_to_dict, sql, *args + results = await self.db_pool.execute( + "search_msgs", self.db_pool.cursor_to_dict, sql, *args ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -444,12 +438,12 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" - count_results = yield self.db.execute( - "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args + count_results = await self.db_pool.execute( + "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -464,19 +458,25 @@ class SearchStore(SearchBackgroundUpdateStore): "count": count, } - @defer.inlineCallbacks - def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): + async def search_rooms( + self, + room_ids: List[str], + search_term: str, + keys: List[str], + limit, + pagination_token: Optional[str] = None, + ) -> List[dict]: """Performs a full text search over events with given keys. Args: - room_id (list): The room_ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports - "content.body", "content.name", "content.topic" - pagination_token (str): A pagination token previously returned + room_ids: The room_ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports "content.body", + "content.name", "content.topic" + pagination_token: A pagination token previously returned Returns: - list of dicts + Each match as a dictionary. """ clauses = [] @@ -579,15 +579,15 @@ class SearchStore(SearchBackgroundUpdateStore): args.append(limit) - results = yield self.db.execute( - "search_rooms", self.db.cursor_to_dict, sql, *args + results = await self.db_pool.execute( + "search_rooms", self.db_pool.cursor_to_dict, sql, *args ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -596,12 +596,12 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" - count_results = yield self.db.execute( - "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args + count_results = await self.db_pool.execute( + "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -621,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore): "count": count, } - def _find_highlights_in_postgres(self, search_query, events): + async def _find_highlights_in_postgres( + self, search_query: str, events: List[EventBase] + ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -629,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore): highlight the matching parts. Args: - search_query (str) - events (list): A list of events + search_query + events: A list of events Returns: - deferred : A set of strings. + A set of strings. """ def f(txn): @@ -686,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore): return highlight_words - return self.db.runInteraction("_find_highlights", f) + return await self.db_pool.runInteraction("_find_highlights", f) def _to_postgres_options(options_dict): diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/databases/main/signatures.py index 36244d9f5d..c8c67953e4 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unpaddedbase64 import encode_base64 +from typing import Dict, Iterable, List, Tuple -from twisted.internet import defer +from unpaddedbase64 import encode_base64 from synapse.storage._base import SQLBaseStore +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -31,18 +32,38 @@ class SignatureWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 ) - def get_event_reference_hashes(self, event_ids): + async def get_event_reference_hashes( + self, event_ids: Iterable[str] + ) -> Dict[str, Dict[str, bytes]]: + """Get all hashes for given events. + + Args: + event_ids: The event IDs to get hashes for. + + Returns: + A mapping of event ID to a mapping of algorithm to hash. + """ + def f(txn): return { event_id: self._get_event_reference_hashes_txn(txn, event_id) for event_id in event_ids } - return self.db.runInteraction("get_event_reference_hashes", f) + return await self.db_pool.runInteraction("get_event_reference_hashes", f) + + async def add_event_hashes( + self, event_ids: Iterable[str] + ) -> List[Tuple[str, Dict[str, str]]]: + """ - @defer.inlineCallbacks - def add_event_hashes(self, event_ids): - hashes = yield self.get_event_reference_hashes(event_ids) + Args: + event_ids: The event IDs + + Returns: + A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. + """ + hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} for e_id, h in hashes.items() @@ -50,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore): return list(hashes.items()) - def _get_event_reference_hashes_txn(self, txn, event_id): + def _get_event_reference_hashes_txn( + self, txn: Cursor, event_id: str + ) -> Dict[str, bytes]: """Get all the hashes for a given PDU. Args: - txn (cursor): - event_id (str): Id for the Event. + txn: + event_id: Id for the Event. Returns: - A dict[unicode, bytes] of algorithm -> hash. + A mapping of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/databases/main/state.py index 347cc50778..5c6168e301 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -16,17 +16,18 @@ import collections.abc import logging from collections import namedtuple - -from twisted.internet import defer +from typing import Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter +from synapse.types import StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -54,7 +55,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: @@ -93,7 +94,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # We really should have an entry in the rooms table for every room we # care about, but let's be a bit paranoid (at least while the background # update is happening) to avoid breaking existing rooms. - version = await self.db.simple_select_one_onecol( + version = await self.db_pool.simple_select_one_onecol( table="rooms", keyvalues={"room_id": room_id}, retcol="room_version", @@ -108,28 +109,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = await self.get_create_event_for_room(room_id) return create_event.content.get("room_version", "1") - @defer.inlineCallbacks - def get_room_predecessor(self, room_id): + async def get_room_predecessor(self, room_id: str) -> Optional[dict]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[dict|None]: A dictionary containing the structure of the predecessor - field from the room's create event. The structure is subject to other servers, - but it is expected to be: - * room_id (str): The room ID of the predecessor room - * event_id (str): The ID of the tombstone event in the predecessor room + A dictionary containing the structure of the predecessor + field from the room's create event. The structure is subject to other servers, + but it is expected to be: + * room_id (str): The room ID of the predecessor room + * event_id (str): The ID of the tombstone event in the predecessor room - None if a predecessor key is not found, or is not a dictionary. + None if a predecessor key is not found, or is not a dictionary. Raises: NotFoundError if the given room is unknown """ # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) + create_event = await self.get_create_event_for_room(room_id) # Retrieve the predecessor key of the create event predecessor = create_event.content.get("predecessor", None) @@ -140,20 +140,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return predecessor - @defer.inlineCallbacks - def get_create_event_for_room(self, room_id): + async def get_create_event_for_room(self, room_id: str) -> EventBase: """Get the create state event for a room. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[EventBase]: The room creation event. + The room creation event. Raises: NotFoundError if the room is unknown """ - state_ids = yield self.get_current_state_ids(room_id) + state_ids = await self.get_current_state_ids(room_id) create_id = state_ids.get((EventTypes.Create, "")) # If we can't find the create event, assume we've hit a dead end @@ -161,19 +160,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return - create_event = yield self.get_event(create_id) + create_event = await self.get_event(create_id) return create_event @cached(max_entries=100000, iterable=True) - def get_current_state_ids(self, room_id): + async def get_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. Args: - room_id (str) + room_id: The room to get the state IDs of. Returns: - deferred: dict of (type, state_key) -> event_id + The current state of the room. """ def _get_current_state_ids_txn(txn): @@ -186,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - def get_filtered_current_state_ids( + async def get_filtered_current_state_ids( self, room_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state @@ -204,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): from the database. Returns: - defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. + Map from type/state_key to event ID. """ where_clause, where_args = state_filter.make_sql_filter_clause() if not where_clause: # We delegate to the cached version - return self.get_current_state_ids(room_id) + return await self.get_current_state_ids(room_id) def _get_filtered_current_state_ids_txn(txn): results = {} @@ -233,22 +232,21 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - @defer.inlineCallbacks - def get_canonical_alias_for_room(self, room_id): + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: """Get canonical alias for room, if any Args: - room_id (str) + room_id: The room ID Returns: - Deferred[str|None]: The canonical alias, if any + The canonical alias, if any """ - state = yield self.get_filtered_current_state_ids( + state = await self.get_filtered_current_state_ids( room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) ) @@ -256,15 +254,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not event_id: return - event = yield self.get_event(event_id, allow_none=True) + event = await self.get_event(event_id, allow_none=True) if not event: return return event.content.get("canonical_alias") @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self.db.simple_select_one_onecol( + async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: + return await self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -276,12 +274,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cached_method_name="_get_state_group_for_event", list_name="event_ids", num_args=1, - inlineCallbacks=True, ) - def _get_state_group_for_events(self, event_ids): + async def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, @@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {row["event_id"]: row["state_group"] for row in rows} - @defer.inlineCallbacks - def get_referenced_state_groups(self, state_groups): + async def get_referenced_state_groups( + self, state_groups: Iterable[int] + ) -> Set[int]: """Check if the state groups are referenced by events. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[set[int]]: The subset of state groups that are - referenced. + The subset of state groups that are referenced. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, @@ -322,25 +319,25 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", table="current_state_events", columns=["state_key"], where_clause="type='m.room.member'", ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, index_name="event_to_state_groups_sg_index", table="event_to_state_groups", columns=["state_group"], ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, ) @@ -353,6 +350,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): last_room_id = progress.get("last_room_id", "") def _background_remove_left_rooms_txn(txn): + # get a batch of room ids to consider sql = """ SELECT DISTINCT room_id FROM current_state_events WHERE room_id > ? ORDER BY room_id LIMIT ? @@ -363,35 +361,79 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): if not room_ids: return True, set() + ########################################################################### + # + # exclude rooms where we have active members + sql = """ SELECT room_id - FROM current_state_events + FROM local_current_membership WHERE room_id > ? AND room_id <= ? - AND type = 'm.room.member' AND membership = 'join' - AND state_key LIKE ? GROUP BY room_id """ - txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name)) - + txn.execute(sql, (last_room_id, room_ids[-1])) joined_room_ids = {row[0] for row in txn} + to_delete = set(room_ids) - joined_room_ids + + ########################################################################### + # + # exclude rooms which we are in the process of constructing; these otherwise + # qualify as "rooms with no local users", and would have their + # forward extremities cleaned up. + + # the following query will return a list of rooms which have forward + # extremities that are *not* also the create event in the room - ie + # those that are not being created currently. + + sql = """ + SELECT DISTINCT efe.room_id + FROM event_forward_extremities efe + LEFT JOIN current_state_events cse ON + cse.event_id = efe.event_id + AND cse.type = 'm.room.create' + AND cse.state_key = '' + WHERE + cse.event_id IS NULL + AND efe.room_id > ? AND efe.room_id <= ? + """ + + txn.execute(sql, (last_room_id, room_ids[-1])) + + # build a set of those rooms within `to_delete` that do not appear in + # the above, leaving us with the rooms in `to_delete` that *are* being + # created. + creating_rooms = to_delete.difference(row[0] for row in txn) + logger.info("skipping rooms which are being created: %s", creating_rooms) + + # now remove the rooms being created from the list of those to delete. + # + # (we could have just taken the intersection of `to_delete` with the result + # of the sql query, but it's useful to be able to log `creating_rooms`; and + # having done so, it's quicker to remove the (few) creating rooms from + # `to_delete` than it is to form the intersection with the (larger) list of + # not-creating-rooms) + + to_delete -= creating_rooms - left_rooms = set(room_ids) - joined_room_ids + ########################################################################### + # + # now clear the state for the rooms - logger.info("Deleting current state left rooms: %r", left_rooms) + logger.info("Deleting current state left rooms: %r", to_delete) # First we get all users that we still think were joined to the # room. This is so that we can mark those device lists as # potentially stale, since there may have been a period where the # server didn't share a room with the remote user and therefore may # have missed any device updates. - rows = self.db.simple_select_many_txn( + rows = self.db_pool.simple_select_many_txn( txn, table="current_state_events", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, retcols=("state_key",), ) @@ -399,23 +441,23 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): potentially_left_users = {row["state_key"] for row in rows} # Now lets actually delete the rooms from the DB. - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="current_state_events", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={}, ) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="event_forward_extremities", column="room_id", - iterable=left_rooms, + iterable=to_delete, keyvalues={}, ) - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.DELETE_CURRENT_STATE_UPDATE_NAME, {"last_room_id": room_ids[-1]}, @@ -423,12 +465,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): return False, potentially_left_users - finished, potentially_left_users = await self.db.runInteraction( + finished, potentially_left_users = await self.db_pool.runInteraction( "_background_remove_left_rooms", _background_remove_left_rooms_txn ) if finished: - await self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.DELETE_CURRENT_STATE_UPDATE_NAME ) @@ -463,5 +505,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StateStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 725e12507f..356623fc6e 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -14,8 +14,7 @@ # limitations under the License. import logging - -from twisted.internet import defer +from typing import Any, Dict, List, Tuple from synapse.storage._base import SQLBaseStore @@ -23,7 +22,9 @@ logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): + async def get_current_state_deltas( + self, prev_stream_id: int, max_stream_id: int + ) -> Tuple[int, List[Dict[str, Any]]]: """Fetch a list of room state changes since the given stream id Each entry in the result contains the following fields: @@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore): if it's new state. Args: - prev_stream_id (int): point to get changes since (exclusive) - max_stream_id (int): the point that we know has been correctly persisted + prev_stream_id: point to get changes since (exclusive) + max_stream_id: the point that we know has been correctly persisted - ie, an upper limit to return changes from. Returns: - Deferred[tuple[int, list[dict]]: A tuple consisting of: + A tuple consisting of: - the stream id which these results go up to - list of current_state_delta_stream rows. If it is empty, we are up to date. @@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore): # if the CSDs haven't changed between prev_stream_id and now, we # know for certain that they haven't changed between prev_stream_id and # max_stream_id. - return defer.succeed((max_stream_id, [])) + return (max_stream_id, []) def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than @@ -100,22 +101,22 @@ class StateDeltasStore(SQLBaseStore): ORDER BY stream_id ASC """ txn.execute(sql, (prev_stream_id, clipped_stream_id)) - return clipped_stream_id, self.db.cursor_to_dict(txn) + return clipped_stream_id, self.db_pool.cursor_to_dict(txn) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self.db.simple_select_one_onecol_txn( + return self.db_pool.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, retcol="COALESCE(MAX(stream_id), -1)", ) - def get_max_stream_id_in_current_state_deltas(self): - return self.db.runInteraction( + async def get_max_stream_id_in_current_state_deltas(self): + return await self.db_pool.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/databases/main/stats.py index 380c1ec7da..55a250ef06 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -15,14 +15,15 @@ # limitations under the License. import logging +from collections import Counter from itertools import chain +from typing import Any, Dict, List, Optional, Tuple -from twisted.internet import defer from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership -from synapse.storage.data_stores.main.state_deltas import StateDeltasStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached @@ -59,7 +60,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StatsStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -69,17 +70,20 @@ class StatsStore(StateDeltasStore): self.stats_delta_processing_lock = DeferredLock() - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( + "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2 + ) + self.db_pool.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves # the potential to reintroduce it in the future – so documentation # will still encourage the use of this no-op handler. - self.db.updates.register_noop_background_update("populate_stats_cleanup") - self.db.updates.register_noop_background_update("populate_stats_prepare") + self.db_pool.updates.register_noop_background_update("populate_stats_cleanup") + self.db_pool.updates.register_noop_background_update("populate_stats_prepare") def quantise_stats_time(self, ts): """ @@ -97,13 +101,14 @@ class StatsStore(StateDeltasStore): """ return (ts // self.stats_bucket_size) * self.stats_bucket_size - @defer.inlineCallbacks - def _populate_stats_process_users(self, progress, batch_size): + async def _populate_stats_process_users(self, progress, batch_size): """ This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db_pool.updates._end_background_update( + "populate_stats_process_users" + ) return 1 last_user_id = progress.get("last_user_id", "") @@ -118,35 +123,57 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.db.runInteraction( + users_to_work_on = await self.db_pool.runInteraction( "_populate_stats_process_users", _get_next_batch ) # No more rooms -- complete the transaction. if not users_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db_pool.updates._end_background_update( + "populate_stats_process_users" + ) return 1 for user_id in users_to_work_on: - yield self._calculate_and_set_initial_state_for_user(user_id) + await self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.db.runInteraction( + await self.db_pool.runInteraction( "populate_stats_process_users", - self.db.updates._background_update_progress_txn, + self.db_pool.updates._background_update_progress_txn, "populate_stats_process_users", progress, ) return len(users_to_work_on) - @defer.inlineCallbacks - def _populate_stats_process_rooms(self, progress, batch_size): + async def _populate_stats_process_rooms(self, progress, batch_size): + """ + This was a background update which regenerated statistics for rooms. + + It has been replaced by StatsStore._populate_stats_process_rooms_2. This background + job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure + someone upgrading from <v1.0.0, this background task has been turned into a no-op + so that the potentially expensive task is not run twice. + + Further context: https://github.com/matrix-org/synapse/pull/7977 + """ + await self.db_pool.updates._end_background_update( + "populate_stats_process_rooms" + ) + return 1 + + async def _populate_stats_process_rooms_2(self, progress, batch_size): """ This is a background update which regenerates statistics for rooms. + + It replaces StatsStore._populate_stats_process_rooms. See its docstring for the + reasoning. """ if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db_pool.updates._end_background_update( + "populate_stats_process_rooms_2" + ) return 1 last_room_id = progress.get("last_room_id", "") @@ -161,48 +188,68 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.db.runInteraction( - "populate_stats_rooms_get_batch", _get_next_batch + rooms_to_work_on = await self.db_pool.runInteraction( + "populate_stats_rooms_2_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db_pool.updates._end_background_update( + "populate_stats_process_rooms_2" + ) return 1 for room_id in rooms_to_work_on: - yield self._calculate_and_set_initial_state_for_room(room_id) + await self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.db.runInteraction( - "_populate_stats_process_rooms", - self.db.updates._background_update_progress_txn, - "populate_stats_process_rooms", + await self.db_pool.runInteraction( + "_populate_stats_process_rooms_2", + self.db_pool.updates._background_update_progress_txn, + "populate_stats_process_rooms_2", progress, ) return len(rooms_to_work_on) - def get_stats_positions(self): + async def get_stats_positions(self) -> int: """ Returns the stats processor positions. """ - return self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", desc="stats_incremental_position", ) - def update_room_state(self, room_id, fields): - """ + async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: + """Update the state of a room. + + fields can contain the following keys with string values: + * join_rules + * history_visibility + * encryption + * name + * topic + * avatar + * canonical_alias + + A is_federatable key can also be included with a boolean value. + Args: - room_id (str) - fields (dict[str:Any]) + room_id: The room ID to update the state of. + fields: The fields to update. This can include a partial list of the + above fields to only update some room information. """ - - # For whatever reason some of the fields may contain null bytes, which - # postgres isn't a fan of, so we replace those fields with null. + # Ensure that the values to update are valid, they should be strings and + # not contain any null bytes. + # + # Invalid data gets overwritten with null. + # + # Note that a missing value should not be overwritten (it keeps the + # previous value). + sentinel = object() for col in ( "join_rules", "history_visibility", @@ -212,32 +259,34 @@ class StatsStore(StateDeltasStore): "avatar", "canonical_alias", ): - field = fields.get(col) - if field and "\0" in field: + field = fields.get(col, sentinel) + if field is not sentinel and (not isinstance(field, str) or "\0" in field): fields[col] = None - return self.db.simple_upsert( + await self.db_pool.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, desc="update_room_state", ) - def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): + async def get_statistics_for_subject( + self, stats_type: str, stats_id: str, start: str, size: int = 100 + ) -> List[dict]: """ Get statistics for a given subject. Args: - stats_type (str): The type of subject - stats_id (str): The ID of the subject (e.g. room_id or user_id) - start (int): Pagination start. Number of entries, not timestamp. - size (int): How many entries to return. + stats_type: The type of subject + stats_id: The ID of the subject (e.g. room_id or user_id) + start: Pagination start. Number of entries, not timestamp. + size: How many entries to return. Returns: - Deferred[list[dict]], where the dict has the keys of + A list of dicts, where the dict has the keys of ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_statistics_for_subject", self._get_statistics_for_subject_txn, stats_type, @@ -258,7 +307,7 @@ class StatsStore(StateDeltasStore): ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self.db.simple_select_list_paginate_txn( + slice_list = self.db_pool.simple_select_list_paginate_txn( txn, table + "_historical", "end_ts", @@ -272,7 +321,7 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - def get_earliest_token_for_stats(self, stats_type, id): + async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -280,29 +329,28 @@ class StatsStore(StateDeltasStore): being calculated. Returns: - Deferred[int] + The earliest token. """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.db.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", allow_none=True, ) - def bulk_update_stats_delta(self, ts, updates, stream_id): + async def bulk_update_stats_delta( + self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int + ) -> None: """Bulk update stats tables for a given stream_id and updates the stats incremental position. Args: - ts (int): Current timestamp in ms - updates(dict[str, dict[str, dict[str, Counter]]]): The updates to - commit as a mapping stats_type -> stats_id -> field -> delta. - stream_id (int): Current position. - - Returns: - Deferred + ts: Current timestamp in ms + updates: The updates to commit as a mapping of + stats_type -> stats_id -> field -> delta. + stream_id: Current position. """ def _bulk_update_stats_delta_txn(txn): @@ -320,45 +368,44 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=stream_id, ) - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": stream_id}, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) - def update_stats_delta( + async def update_stats_delta( self, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id, - absolute_field_overrides=None, - ): + ts: int, + stats_type: str, + stats_id: str, + fields: Dict[str, int], + complete_with_stream_id: Optional[int], + absolute_field_overrides: Optional[Dict[str, int]] = None, + ) -> None: """ Updates the statistics for a subject, with a delta (difference/relative change). Args: - ts (int): timestamp of the change - stats_type (str): "room" or "user" – the kind of subject - stats_id (str): the subject's ID (room ID or user ID) - fields (dict[str, int]): Deltas of stats values. - complete_with_stream_id (int, optional): + ts: timestamp of the change + stats_type: "room" or "user" – the kind of subject + stats_id: the subject's ID (room ID or user ID) + fields: Deltas of stats values. + complete_with_stream_id: If supplied, converts an incomplete row into a complete row, with the supplied stream_id marked as the stream_id where the row was completed. - absolute_field_overrides (dict[str, int]): Current stats values - (i.e. not deltas) of absolute fields. - Does not work with per-slice fields. + absolute_field_overrides: Current stats values (i.e. not deltas) of + absolute fields. Does not work with per-slice fields. """ - return self.db.runInteraction( + await self.db_pool.runInteraction( "update_stats_delta", self._update_stats_delta_txn, ts, @@ -493,17 +540,17 @@ class StatsStore(StateDeltasStore): else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self.db.simple_select_one_txn( + current_row = self.db_pool.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self.db.simple_insert_txn(txn, table, merged_dict) + self.db_pool.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self.db.simple_update_one_txn(txn, table, keyvalues, current_row) + self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -590,11 +637,11 @@ class StatsStore(StateDeltasStore): txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self.db.simple_select_one_txn( + src_row = self.db_pool.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self.db.simple_select_one_txn( + dest_current_row = self.db_pool.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -610,25 +657,28 @@ class StatsStore(StateDeltasStore): **src_row, **additive_relatives, } - self.db.simple_insert_txn(txn, into_table, merged_dict) + self.db_pool.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.db_pool.simple_update_txn( + txn, into_table, all_dest_keyvalues, src_row + ) - def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): + async def get_changes_room_total_events_and_bytes( + self, min_pos: int, max_pos: int + ) -> Dict[str, Dict[str, int]]: """Fetches the counts of events in the given range of stream IDs. Args: - min_pos (int) - max_pos (int) + min_pos + max_pos Returns: - Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field - changes. + Mapping of room ID to field changes. """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "stats_incremental_total_events_and_bytes", self.get_changes_room_total_events_and_bytes_txn, min_pos, @@ -696,22 +746,22 @@ class StatsStore(StateDeltasStore): return room_deltas, user_deltas - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_room(self, room_id): + async def _calculate_and_set_initial_state_for_room( + self, room_id: str + ) -> Tuple[dict, dict, int]: """Calculate and insert an entry into room_stats_current. Args: - room_id (str) + room_id: The room ID under calculation. Returns: - Deferred[tuple[dict, dict, int]]: A tuple of room state, membership - counts and stream position. + A tuple of room state, membership counts and stream position. """ def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self.db.simple_select_many_txn( + rows = self.db_pool.simple_select_many_txn( txn, table="current_state_events", column="type", @@ -767,11 +817,11 @@ class StatsStore(StateDeltasStore): current_state_events_count, users_in_room, pos, - ) = yield self.db.runInteraction( + ) = await self.db_pool.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) - state_event_map = yield self.get_events(event_ids, get_prev_content=False) + state_event_map = await self.get_events(event_ids, get_prev_content=False) room_state = { "join_rules": None, @@ -806,11 +856,11 @@ class StatsStore(StateDeltasStore): event.content.get("m.federate", True) is True ) - yield self.update_room_state(room_id, room_state) + await self.update_room_state(room_id, room_state) local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="room", stats_id=room_id, @@ -826,8 +876,7 @@ class StatsStore(StateDeltasStore): }, ) - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_user(self, user_id): + async def _calculate_and_set_initial_state_for_user(self, user_id): def _calculate_and_set_initial_state_for_user_txn(txn): pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) @@ -842,12 +891,12 @@ class StatsStore(StateDeltasStore): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.db.runInteraction( + joined_rooms, pos = await self.db_pool.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="user", stats_id=user_id, diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/databases/main/stream.py index e89f0bffb5..be6df8a6d1 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -39,19 +39,27 @@ what sort order was used: import abc import logging from collections import namedtuple - -from six.moves import range +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from twisted.internet import defer +from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine -from synapse.types import RoomStreamToken +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.types import Collection, RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -69,8 +77,12 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine -): + direction: str, + column_names: Tuple[str, str], + from_token: Optional[Tuple[int, int]], + to_token: Optional[Tuple[int, int]], + engine: BaseDatabaseEngine, +) -> str: """Creates an SQL expression to bound the columns by the pagination tokens. @@ -91,21 +103,19 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction (str): Whether we're paginating backwards("b") or - forwards ("f"). - column_names (tuple[str, str]): The column names to bound. Must *not* - be user defined as these get inserted directly into the SQL - statement without escapes. - from_token (tuple[int, int]|None): The start point for the pagination. - This is an exclusive minimum bound if direction is "f", and an - inclusive maximum bound if direction is "b". - to_token (tuple[int, int]|None): The endpoint point for the pagination. - This is an inclusive maximum bound if direction is "f", and an - exclusive minimum bound if direction is "b". + direction: Whether we're paginating backwards("b") or forwards ("f"). + column_names: The column names to bound. Must *not* be user defined as + these get inserted directly into the SQL statement without escapes. + from_token: The start point for the pagination. This is an exclusive + minimum bound if direction is "f", and an inclusive maximum bound if + direction is "b". + to_token: The endpoint point for the pagination. This is an inclusive + maximum bound if direction is "f", and an exclusive minimum bound if + direction is "b". engine: The database engine to generate the clauses for Returns: - str: The sql expression + The sql expression """ assert direction in ("b", "f") @@ -133,7 +143,12 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) -def _make_generic_sql_bound(bound, column_names, values, engine): +def _make_generic_sql_bound( + bound: str, + column_names: Tuple[str, str], + values: Tuple[Optional[int], int], + engine: BaseDatabaseEngine, +) -> str: """Create an SQL expression that bounds the given column names by the values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. @@ -143,18 +158,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine): out manually. Args: - bound (str): The comparison operator to use. One of ">", "<", ">=", + bound: The comparison operator to use. One of ">", "<", ">=", "<=", where the values are on the left and columns on the right. - names (tuple[str, str]): The column names. Must *not* be user defined + names: The column names. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - values (tuple[int|None, int]): The values to bound the columns by. If + values: The values to bound the columns by. If the first value is None then only creates a bound on the second column. engine: The database engine to generate the SQL for Returns: - str + The SQL statement """ assert bound in (">", "<", ">=", "<=") @@ -194,7 +209,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): ) -def filter_to_clause(event_filter): +def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create # "room_id == X AND room_id != X", which postgres doesn't optimise. @@ -252,11 +267,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super(StreamWorkerStore, self).__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._send_federation = hs.should_send_federation() + self._federation_shard_config = hs.config.worker.federation_shard_config + + # If we're a process that sends federation we may need to reset the + # `federation_stream_position` table to match the current sharding + # config. We don't do this now as otherwise two processes could conflict + # during startup which would cause one to die. + self._need_to_reset_federation_stream_positions = self._send_federation + events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self.db.get_cache_dict( + event_cache_prefill, min_event_val = self.db_pool.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -275,41 +300,42 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self._stream_order_on_start = self.get_room_max_stream_ordering() @abc.abstractmethod - def get_room_max_stream_ordering(self): + def get_room_max_stream_ordering(self) -> int: raise NotImplementedError() @abc.abstractmethod - def get_room_min_stream_ordering(self): + def get_room_min_stream_ordering(self) -> int: raise NotImplementedError() - @defer.inlineCallbacks - def get_room_events_stream_for_rooms( - self, room_ids, from_key, to_key, limit=0, order="DESC" - ): + async def get_room_events_stream_for_rooms( + self, + room_ids: Collection[str], + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Dict[str, Tuple[List[EventBase], str]]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_ids + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: - Deferred[dict[str,tuple[list[FrozenEvent], str]]] - A map from room id to a tuple containing: - - list of recent events in the room - - stream ordering key for the start of the chunk of events returned. + A map from room id to a tuple containing: + - list of recent events in the room + - stream ordering key for the start of the chunk of events returned. """ from_id = RoomStreamToken.parse_stream_token(from_key).stream - room_ids = yield self._events_stream_cache.get_entities_changed( - room_ids, from_id - ) + room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id) if not room_ids: return {} @@ -317,7 +343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)): - res = yield make_deferred_yieldable( + res = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -337,43 +363,47 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return results - def get_rooms_that_changed(self, room_ids, from_key): + def get_rooms_that_changed( + self, room_ids: Collection[str], from_key: str + ) -> Set[str]: """Given a list of rooms and a token, return rooms where there may have been changes. Args: - room_ids (list) - from_key (str): The room_key portion of a StreamToken + room_ids + from_key: The room_key portion of a StreamToken """ - from_key = RoomStreamToken.parse_stream_token(from_key).stream + from_id = RoomStreamToken.parse_stream_token(from_key).stream return { room_id for room_id in room_ids - if self._events_stream_cache.has_entity_changed(room_id, from_key) + if self._events_stream_cache.has_entity_changed(room_id, from_id) } - @defer.inlineCallbacks - def get_room_events_stream_for_room( - self, room_id, from_key, to_key, limit=0, order="DESC" - ): - + async def get_room_events_stream_for_room( + self, + room_id: str, + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Tuple[List[EventBase], str]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_id + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns the list of - events (in ascending order) and the token from the start of - the chunk of events returned. + The list of events (in ascending order) and the token from the start + of the chunk of events returned. """ if from_key == to_key: return [], from_key @@ -381,9 +411,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream - has_changed = yield self._events_stream_cache.has_entity_changed( - room_id, from_id - ) + has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) if not has_changed: return [], from_key @@ -401,9 +429,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.db.runInteraction("get_room_events_stream_for_room", f) + rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -421,8 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key - @defer.inlineCallbacks - def get_membership_changes_for_user(self, user_id, from_key, to_key): + async def get_membership_changes_for_user( + self, user_id: str, from_key: str, to_key: str + ) -> List[EventBase]: from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -451,9 +480,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows - rows = yield self.db.runInteraction("get_membership_changes_for_user", f) + rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -461,27 +490,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token): + async def get_recent_events_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[EventBase], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of - events and a token pointing to the start of the returned - events. - The events returned are in ascending order. + A list of events and a token pointing to the start of the returned + events. The events returned are in ascending order. """ - rows, token = yield self.get_recent_event_ids_for_room( + rows, token = await self.get_recent_event_ids_for_room( room_id, limit, end_token ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -489,20 +517,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return (events, token) - @defer.inlineCallbacks - def get_recent_event_ids_for_room(self, room_id, limit, end_token): + async def get_recent_event_ids_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[_EventDictReturn], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of - _EventDictReturn and a token pointing to the start of the returned - events. - The events returned are in ascending order. + A list of _EventDictReturn and a token pointing to the start of the + returned events. The events returned are in ascending order. """ # Allow a zero limit here, and no-op. if limit == 0: @@ -510,7 +537,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.db.runInteraction( + rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -523,16 +550,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, token - def get_room_event_before_stream_ordering(self, room_id, stream_ordering): + async def get_room_event_before_stream_ordering( + self, room_id: str, stream_ordering: int + ) -> Tuple[int, int, str]: """Gets details of the first event in a room at or before a stream ordering Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: - Deferred[(int, int, str)]: - (stream ordering, topological ordering, event_id) + A tuple of (stream ordering, topological ordering, event_id) """ def _f(txn): @@ -547,76 +575,100 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.db.runInteraction("get_room_event_before_stream_ordering", _f) + return await self.db_pool.runInteraction( + "get_room_event_before_stream_ordering", _f + ) - @defer.inlineCallbacks - def get_room_events_max_id(self, room_id=None): + async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: """Returns the current token for rooms stream. By default, it returns the current global stream token. Specifying a `room_id` causes it to return the current room specific topological token. """ - token = yield self.get_room_max_stream_ordering() + token = self.get_room_max_stream_ordering() if room_id is None: return "s%d" % (token,) else: - topo = yield self.db.runInteraction( + topo = await self.db_pool.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) return "t%d-%d" % (topo, token) - def get_stream_token_for_event(self, event_id): + async def get_stream_id_for_event(self, event_id: str) -> int: + """The stream ID for an event + Args: + event_id: The id of the event to look up a stream token for. + Raises: + StoreError if the event wasn't in the database. + Returns: + A stream ID. + """ + return await self.db_pool.runInteraction( + "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id, + ) + + def get_stream_id_for_event_txn( + self, txn: LoggingTransaction, event_id: str, allow_none=False, + ) -> int: + return self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"event_id": event_id}, + retcol="stream_ordering", + allow_none=allow_none, + ) + + async def get_stream_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "s%d" stream token. + A "s%d" stream token. """ - return self.db.simple_select_one_onecol( - table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" - ).addCallback(lambda row: "s%d" % (row,)) + stream_id = await self.get_stream_id_for_event(event_id) + return "s%d" % (stream_id,) - def get_topological_token_for_event(self, event_id): + async def get_topological_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "t%d-%d" topological token. + A "t%d-%d" topological token. """ - return self.db.simple_select_one( + row = await self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", - ).addCallback( - lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) ) + return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) - def get_max_topological_token(self, room_id, stream_key): + async def get_max_topological_token(self, room_id: str, stream_key: int) -> int: """Get the max topological token in a room before the given stream ordering. Args: - room_id (str) - stream_key (int) + room_id + stream_key Returns: - Deferred[int] + The maximum topological token. """ sql = ( "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self.db.execute( + row = await self.db_pool.execute( "get_max_topological_token", None, sql, room_id, stream_key - ).addCallback(lambda r: r[0][0] if r else 0) + ) + return row[0][0] if row else 0 - def _get_max_topological_txn(self, txn, room_id): + def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int: txn.execute( "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", (room_id,), @@ -626,16 +678,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows[0][0] if rows else 0 @staticmethod - def _set_before_and_after(events, rows, topo_order=True): + def _set_before_and_after( + events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True + ): """Inserts ordering information to events' internal metadata from the DB rows. Args: - events (list[FrozenEvent]) - rows (list[_EventDictReturn]) - topo_order (bool): Whether the events were ordered topologically - or by stream ordering. If true then all rows should have a non - null topological_ordering. + events + rows + topo_order: Whether the events were ordered topologically or by stream + ordering. If true then all rows should have a non null + topological_ordering. """ for event, row in zip(events, rows): stream = row.stream_ordering @@ -648,25 +702,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): internal.after = str(RoomStreamToken(topo, stream)) internal.order = (int(topo) if topo else 0, int(stream)) - @defer.inlineCallbacks - def get_events_around( - self, room_id, event_id, before_limit, after_limit, event_filter=None - ): + async def get_events_around( + self, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter] = None, + ) -> dict: """Retrieve events and pagination tokens around a given event in a room. - - Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) - - Returns: - dict """ - results = yield self.db.runInteraction( + results = await self.db_pool.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -676,11 +724,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events_before = yield self.get_events_as_list( + events_before = await self.get_events_as_list( list(results["before"]["event_ids"]), get_prev_content=True ) - events_after = yield self.get_events_as_list( + events_after = await self.get_events_as_list( list(results["after"]["event_ids"]), get_prev_content=True ) @@ -692,29 +740,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): } def _get_events_around_txn( - self, txn, room_id, event_id, before_limit, after_limit, event_filter - ): + self, + txn: LoggingTransaction, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter], + ) -> dict: """Retrieves event_ids and pagination tokens around a given event in a room. Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) + room_id + event_id + before_limit + after_limit + event_filter Returns: dict """ - results = self.db.simple_select_one_txn( + results = self.db_pool.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, retcols=["stream_ordering", "topological_ordering"], ) + # This cannot happen as `allow_none=False`. + assert results is not None + # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( @@ -750,22 +807,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - @defer.inlineCallbacks - def get_all_new_events_stream(self, from_id, current_id, limit): + async def get_all_new_events_stream( + self, from_id: int, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: """Get all new events Returns all events with from_id < stream_ordering <= current_id. Args: - from_id (int): the stream_ordering of the last event we processed - current_id (int): the stream_ordering of the most recently processed event - limit (int): the maximum number of events to return + from_id: the stream_ordering of the last event we processed + current_id: the stream_ordering of the most recently processed event + limit: the maximum number of events to return Returns: - Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where - `next_id` is the next value to pass as `from_id` (it will either be the - stream_ordering of the last returned event, or, if fewer than `limit` events - were found, `current_id`. + A tuple of (next_id, events), where `next_id` is the next value to + pass as `from_id` (it will either be the stream_ordering of the + last returned event, or, if fewer than `limit` events were found, + the `current_id`). """ def get_all_new_events_stream_txn(txn): @@ -787,63 +845,134 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events - def get_federation_out_pos(self, typ): - return self.db.simple_select_one_onecol( + async def get_federation_out_pos(self, typ: str) -> int: + if self._need_to_reset_federation_stream_positions: + await self.db_pool.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db_pool.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, desc="get_federation_out_pos", ) - def update_federation_out_pos(self, typ, stream_id): - return self.db.simple_update_one( + async def update_federation_out_pos(self, typ: str, stream_id: int) -> None: + if self._need_to_reset_federation_stream_positions: + await self.db_pool.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + await self.db_pool.simple_update_one( table="federation_stream_position", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) - def has_room_changed_since(self, room_id, stream_id): + def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None: + """Fiddles with the `federation_stream_position` table to make it match + the configured federation sender instances during start up. + """ + + # The federation sender instances may have changed, so we need to + # massage the `federation_stream_position` table to have a row per type + # per instance sending federation. If there is a mismatch we update the + # table with the correct rows using the *minimum* stream ID seen. This + # may result in resending of events/EDUs to remote servers, but that is + # preferable to dropping them. + + if not self._send_federation: + return + + # Pull out the configured instances. If we don't have a shard config then + # we assume that we're the only instance sending. + configured_instances = self._federation_shard_config.instances + if not configured_instances: + configured_instances = [self._instance_name] + elif self._instance_name not in configured_instances: + return + + instances_in_table = self.db_pool.simple_select_onecol_txn( + txn, + table="federation_stream_position", + keyvalues={}, + retcol="instance_name", + ) + + if set(instances_in_table) == set(configured_instances): + # Nothing to do + return + + sql = """ + SELECT type, MIN(stream_id) FROM federation_stream_position + GROUP BY type + """ + txn.execute(sql) + min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position + + # Ensure we do actually have some values here + assert set(min_positions) == {"federation", "events"} + + sql = """ + DELETE FROM federation_stream_position + WHERE NOT (%s) + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "instance_name", configured_instances + ) + txn.execute(sql % (clause,), args) + + for typ, stream_id in min_positions.items(): + self.db_pool.simple_upsert_txn( + txn, + table="federation_stream_position", + keyvalues={"type": typ, "instance_name": self._instance_name}, + values={"stream_id": stream_id}, + ) + + def has_room_changed_since(self, room_id: str, stream_id: int) -> bool: return self._events_stream_cache.has_entity_changed(room_id, stream_id) def _paginate_room_events_txn( self, - txn, - room_id, - from_token, - to_token=None, - direction="b", - limit=-1, - event_filter=None, - ): + txn: LoggingTransaction, + room_id: str, + from_token: RoomStreamToken, + to_token: Optional[RoomStreamToken] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[_EventDictReturn], str]: """Returns list of events before or after a given token. Args: txn - room_id (str) - from_token (RoomStreamToken): The token used to stream from - to_token (RoomStreamToken|None): A token which if given limits the - results to only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to + room_id + from_token: The token used to stream from + to_token: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns the results - as a list of _EventDictReturn and a token that points to the end - of the result set. If no events are returned then the end of the - stream has been reached (i.e. there are no events between - `from_token` and `to_token`), or `limit` is zero. + A list of _EventDictReturn and a token that points to the end of the + result set. If no events are returned then the end of the stream has + been reached (i.e. there are no events between `from_token` and + `to_token`), or `limit` is zero. """ assert int(limit) >= 0 @@ -927,35 +1056,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, str(next_token) - @defer.inlineCallbacks - def paginate_room_events( - self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None - ): + async def paginate_room_events( + self, + room_id: str, + from_key: str, + to_key: Optional[str] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[EventBase], str]: """Returns list of events before or after a given token. Args: - room_id (str) - from_key (str): The token used to stream from - to_key (str|None): A token which if given limits the results to - only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to - those that match the filter. + room_id + from_key: The token used to stream from + to_key: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. Returns: - tuple[list[FrozenEvent], str]: Returns the results as a list of - events and a token that points to the end of the result set. If no - events are returned then the end of the stream has been reached - (i.e. there are no events between `from_key` and `to_key`). + The results as a list of events and a token that points to the end + of the result set. If no events are returned then the end of the + stream has been reached (i.e. there are no events between `from_key` + and `to_key`). """ from_key = RoomStreamToken.parse(from_key) if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.db.runInteraction( + rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, @@ -966,7 +1098,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -976,8 +1108,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self): + def get_room_max_stream_ordering(self) -> int: return self._stream_id_gen.get_current_token() - def get_room_min_stream_ordering(self): + def get_room_min_stream_ordering(self) -> int: return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/databases/main/tags.py index 4219018302..96ffe26cc9 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,14 +15,12 @@ # limitations under the License. import logging +from typing import Dict, List, Tuple -from six.moves import range - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore +from synapse.storage._base import db_to_json +from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.types import JsonDict +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -30,43 +28,53 @@ logger = logging.getLogger(__name__) class TagsWorkerStore(AccountDataWorkerStore): @cached() - def get_tags_for_user(self, user_id): + async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: """Get all the tags for a user. Args: - user_id(str): The user to get the tags for. + user_id: The user to get the tags for. Returns: - A deferred dict mapping from room_id strings to dicts mapping from - tag strings to tag content. + A mapping from room_id strings to dicts mapping from tag strings to + tag content. """ - deferred = self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) - @deferred.addCallback - def tags_by_room(rows): - tags_by_room = {} - for row in rows: - room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = json.loads(row["content"]) - return tags_by_room + tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]] + for row in rows: + room_tags = tags_by_room.setdefault(row["room_id"], {}) + room_tags[row["tag"]] = db_to_json(row["content"]) + return tags_by_room - return deferred + async def get_all_updated_tags( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for tags replication stream. - @defer.inlineCallbacks - def get_all_updated_tags(self, last_id, current_id, limit): - """Get all the client tags that have changed on the server Args: - last_id(int): The position to fetch from. - current_id(int): The position to fetch up to. + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + Returns: - A deferred list of tuples of stream_id int, user_id string, - room_id string, tag string and content string. + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ + if last_id == current_id: - return [] + return [], current_id, False def get_all_updated_tags_txn(txn): sql = ( @@ -78,7 +86,7 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - tag_ids = yield self.db.runInteraction( + tag_ids = await self.db_pool.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) @@ -89,35 +97,43 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (user_id, room_id)) tags = [] for tag, content in txn: - tags.append(json.dumps(tag) + ":" + content) + tags.append(json_encoder.encode(tag) + ":" + content) tag_json = "{" + ",".join(tags) + "}" - results.append((stream_id, user_id, room_id, tag_json)) + results.append((stream_id, (user_id, room_id, tag_json))) return results batch_size = 50 results = [] for i in range(0, len(tag_ids), batch_size): - tags = yield self.db.runInteraction( + tags = await self.db_pool.runInteraction( "get_all_updated_tag_content", get_tag_content, tag_ids[i : i + batch_size], ) results.extend(tags) - return results + limited = False + upto_token = current_id + if len(results) >= limit: + upto_token = results[-1][0] + limited = True + + return results, upto_token, limited - @defer.inlineCallbacks - def get_updated_tags(self, user_id, stream_id): + async def get_updated_tags( + self, user_id: str, stream_id: int + ) -> Dict[str, Dict[str, JsonDict]]: """Get all the tags for the rooms where the tags have changed since the given version Args: user_id(str): The user to get the tags for. stream_id(int): The earliest update to get for the user. + Returns: - A deferred dict mapping from room_id strings to lists of tag - strings for all the rooms that changed since the stream_id token. + A mapping from room_id strings to lists of tag strings for all the + rooms that changed since the stream_id token. """ def get_updated_tags_txn(txn): @@ -135,52 +151,58 @@ class TagsWorkerStore(AccountDataWorkerStore): if not changed: return {} - room_ids = yield self.db.runInteraction( + room_ids = await self.db_pool.runInteraction( "get_updated_tags", get_updated_tags_txn ) results = {} if room_ids: - tags_by_room = yield self.get_tags_for_user(user_id) + tags_by_room = await self.get_tags_for_user(user_id) for room_id in room_ids: results[room_id] = tags_by_room.get(room_id, {}) return results - def get_tags_for_room(self, user_id, room_id): + async def get_tags_for_room( + self, user_id: str, room_id: str + ) -> Dict[str, JsonDict]: """Get all the tags for the given room + Args: - user_id(str): The user to get tags for - room_id(str): The room to get tags for + user_id: The user to get tags for + room_id: The room to get tags for + Returns: - A deferred list of string tags. + A mapping of tags to tag content. """ - return self.db.simple_select_list( + rows = await self.db_pool.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), desc="get_tags_for_room", - ).addCallback( - lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows} ) + return {row["tag"]: db_to_json(row["content"]) for row in rows} class TagsStore(TagsWorkerStore): - @defer.inlineCallbacks - def add_tag_to_room(self, user_id, room_id, tag, content): + async def add_tag_to_room( + self, user_id: str, room_id: str, tag: str, content: JsonDict + ) -> int: """Add a tag to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - tag(str): The tag name to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + tag: The tag name to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the tag has been added. + The next account data ID. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) def add_tag_txn(txn, next_id): - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, @@ -188,19 +210,18 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: - yield self.db.runInteraction("add_tag", add_tag_txn, next_id) + with await self._account_data_id_gen.get_next() as next_id: + await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def remove_tag_from_room(self, user_id, room_id, tag): + async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: """Remove a tag from a room for a user. + Returns: - A deferred that completes once the tag has been removed + The next account data ID. """ def remove_tag_txn(txn, next_id): @@ -211,22 +232,23 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: - yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id) + with await self._account_data_id_gen.get_next() as next_id: + await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_revision_txn(self, txn, user_id, room_id, next_id): + def _update_revision_txn( + self, txn, user_id: str, room_id: str, next_id: int + ) -> None: """Update the latest revision of the tags for the given user and room. Args: txn: The database cursor - user_id(str): The ID of the user. - room_id(str): The ID of the room. - next_id(int): The the revision to advance to. + user_id: The ID of the user. + room_id: The ID of the room. + next_id: The the revision to advance to. """ txn.call_after( diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/databases/main/transactions.py index a9bf457939..5b31aab700 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -15,14 +15,14 @@ import logging from collections import namedtuple +from typing import Optional, Tuple from canonicaljson import encode_canonical_json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.types import JsonDict from synapse.util.caches.expiringcache import ExpiringCache db_binary_type = memoryview @@ -46,7 +46,7 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(TransactionStore, self).__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) @@ -57,21 +57,23 @@ class TransactionStore(SQLBaseStore): expiry_ms=5 * 60 * 1000, ) - def get_received_txn_response(self, transaction_id, origin): + async def get_received_txn_response( + self, transaction_id: str, origin: str + ) -> Optional[Tuple[int, JsonDict]]: """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response body (as a dict). Args: - transaction_id (str) - origin(str) + transaction_id + origin Returns: - tuple: None if we have not previously responded to - this transaction or a 2-tuple of (int, dict) + None if we have not previously responded to this transaction or a + 2-tuple of (int, dict) """ - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -79,7 +81,7 @@ class TransactionStore(SQLBaseStore): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self.db.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -100,20 +102,21 @@ class TransactionStore(SQLBaseStore): else: return None - def set_received_txn_response(self, transaction_id, origin, code, response_dict): - """Persist the response we returened for an incoming transaction, and + async def set_received_txn_response( + self, transaction_id: str, origin: str, code: int, response_dict: JsonDict + ) -> None: + """Persist the response we returned for an incoming transaction, and should return for subsequent transactions with the same transaction_id and origin. Args: - txn - transaction_id (str) - origin (str) - code (int) - response_json (str) + transaction_id: The incoming transaction ID. + origin: The origin server. + code: The response code. + response_dict: The response, to be encoded into JSON. """ - return self.db.simple_insert( + await self.db_pool.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -126,8 +129,7 @@ class TransactionStore(SQLBaseStore): desc="set_received_txn_response", ) - @defer.inlineCallbacks - def get_destination_retry_timings(self, destination): + async def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. Args: @@ -142,7 +144,7 @@ class TransactionStore(SQLBaseStore): if result is not SENTINEL: return result - result = yield self.db.runInteraction( + result = await self.db_pool.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, @@ -154,7 +156,7 @@ class TransactionStore(SQLBaseStore): return result def _get_destination_retry_timings(self, txn, destination): - result = self.db.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -167,21 +169,25 @@ class TransactionStore(SQLBaseStore): else: return None - def set_destination_retry_timings( - self, destination, failure_ts, retry_last_ts, retry_interval - ): + async def set_destination_retry_timings( + self, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: """Sets the current retry timings for a given destination. Both timings should be zero if retrying is no longer occuring. Args: - destination (str) - failure_ts (int|None) - when the server started failing (ms since epoch) - retry_last_ts (int) - time of last retry attempt in unix epoch ms - retry_interval (int) - how long until next retry in ms + destination + failure_ts: when the server started failing (ms since epoch) + retry_last_ts: time of last retry attempt in unix epoch ms + retry_interval: how long until next retry in ms """ self._destination_retry_cache.pop(destination, None) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -221,7 +227,7 @@ class TransactionStore(SQLBaseStore): # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self.db.simple_select_one_txn( + prev_row = self.db_pool.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -230,7 +236,7 @@ class TransactionStore(SQLBaseStore): ) if not prev_row: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="destinations", values={ @@ -241,7 +247,7 @@ class TransactionStore(SQLBaseStore): }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, @@ -257,13 +263,13 @@ class TransactionStore(SQLBaseStore): "cleanup_transactions", self._cleanup_transactions ) - def _cleanup_transactions(self): + async def _cleanup_transactions(self) -> None: now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.db.runInteraction( + await self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn ) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 1d8ee22fb1..b89668d561 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -12,15 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import attr -import synapse.util.stringutils as stringutils from synapse.api.errors import StoreError -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict +from synapse.util import json_encoder, stringutils @attr.s @@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore): StoreError if a unique session ID cannot be generated. """ # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) # autogen a session ID and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -81,7 +81,7 @@ class UIAuthWorkerStore(SQLBaseStore): session_id = stringutils.random_string(24) try: - await self.db.simple_insert( + await self.db_pool.simple_insert( table="ui_auth_sessions", values={ "session_id": session_id, @@ -97,7 +97,7 @@ class UIAuthWorkerStore(SQLBaseStore): return UIAuthSessionData( session_id, clientdict, uri, method, description ) - except self.db.engine.module.IntegrityError: + except self.db_pool.engine.module.IntegrityError: attempts += 1 raise StoreError(500, "Couldn't generate a session ID.") @@ -111,14 +111,14 @@ class UIAuthWorkerStore(SQLBaseStore): Raises: StoreError if the session is not found. """ - result = await self.db.simple_select_one( + result = await self.db_pool.simple_select_one( table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("clientdict", "uri", "method", "description"), desc="get_ui_auth_session", ) - result["clientdict"] = json.loads(result["clientdict"]) + result["clientdict"] = db_to_json(result["clientdict"]) return UIAuthSessionData(session_id, **result) @@ -140,13 +140,13 @@ class UIAuthWorkerStore(SQLBaseStore): # Note that we need to allow for the same stage to complete multiple # times here so that registration is idempotent. try: - await self.db.simple_upsert( + await self.db_pool.simple_upsert( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id, "stage_type": stage_type}, - values={"result": json.dumps(result)}, + values={"result": json_encoder.encode(result)}, desc="mark_ui_auth_stage_complete", ) - except self.db.engine.module.IntegrityError: + except self.db_pool.engine.module.IntegrityError: raise StoreError(400, "Unknown session ID: %s" % (session_id,)) async def get_completed_ui_auth_stages( @@ -162,13 +162,13 @@ class UIAuthWorkerStore(SQLBaseStore): that auth-type. """ results = {} - for row in await self.db.simple_select_list( + for row in await self.db_pool.simple_select_list( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id}, retcols=("stage_type", "result"), desc="get_completed_ui_auth_stages", ): - results[row["stage_type"]] = json.loads(row["result"]) + results[row["stage_type"]] = db_to_json(row["result"]) return results @@ -184,9 +184,9 @@ class UIAuthWorkerStore(SQLBaseStore): The dictionary from the client root level, not the 'auth' key. """ # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) - self.db.simple_update_one( + await self.db_pool.simple_update_one( table="ui_auth_sessions", keyvalues={"session_id": session_id}, updatevalues={"clientdict": clientdict_json}, @@ -206,7 +206,7 @@ class UIAuthWorkerStore(SQLBaseStore): Raises: StoreError if the session cannot be found. """ - await self.db.runInteraction( + await self.db_pool.runInteraction( "set_ui_auth_session_data", self._set_ui_auth_session_data_txn, session_id, @@ -214,24 +214,26 @@ class UIAuthWorkerStore(SQLBaseStore): value, ) - def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + def _set_ui_auth_session_data_txn( + self, txn: LoggingTransaction, session_id: str, key: str, value: Any + ): # Get the current value. - result = self.db.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), - ) + ) # type: Dict[str, Any] # type: ignore # Update it and add it back to the database. - serverdict = json.loads(result["serverdict"]) + serverdict = db_to_json(result["serverdict"]) serverdict[key] = value - self.db.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, - updatevalues={"serverdict": json.dumps(serverdict)}, + updatevalues={"serverdict": json_encoder.encode(serverdict)}, ) async def get_ui_auth_session_data( @@ -247,20 +249,48 @@ class UIAuthWorkerStore(SQLBaseStore): Raises: StoreError if the session cannot be found. """ - result = await self.db.simple_select_one( + result = await self.db_pool.simple_select_one( table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), desc="get_ui_auth_session_data", ) - serverdict = json.loads(result["serverdict"]) + serverdict = db_to_json(result["serverdict"]) return serverdict.get(key, default) + async def add_user_agent_ip_to_ui_auth_session( + self, session_id: str, user_agent: str, ip: str, + ): + """Add the given user agent / IP to the tracking table + """ + await self.db_pool.simple_upsert( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, + values={}, + desc="add_user_agent_ip_to_ui_auth_session", + ) + + async def get_user_agents_ips_to_ui_auth_session( + self, session_id: str, + ) -> List[Tuple[str, str]]: + """Get the given user agents / IPs used during the ui auth process + + Returns: + List of user_agent/ip pairs + """ + rows = await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ) + return [(row["user_agent"], row["ip"]) for row in rows] + class UIAuthStore(UIAuthWorkerStore): - def delete_old_ui_auth_sessions(self, expiration_time: int): + async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -269,20 +299,31 @@ class UIAuthStore(UIAuthWorkerStore): This is an epoch time in milliseconds. """ - return self.db.runInteraction( + await self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, ) - def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + def _delete_old_ui_auth_sessions_txn( + self, txn: LoggingTransaction, expiration_time: int + ): # Get the expired sessions. sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] + # Delete the corresponding IP/user agents. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_ips", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + # Delete the corresponding completed credentials. - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="ui_auth_sessions_credentials", column="session_id", @@ -291,7 +332,7 @@ class UIAuthStore(UIAuthWorkerStore): ) # Finally, delete the sessions. - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="ui_auth_sessions", column="session_id", diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 6b8130bf0f..f2f9a5799a 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,13 +15,12 @@ import logging import re - -from twisted.internet import defer +from typing import Any, Dict, Iterable, Optional, Set, Tuple from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.data_stores.main.state import StateFilter -from synapse.storage.data_stores.main.state_deltas import StateDeltasStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.state import StateFilter +from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -38,29 +37,28 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_user_directory_createtables", self._populate_user_directory_createtables, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_user_directory_process_rooms", self._populate_user_directory_process_rooms, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_user_directory_process_users", self._populate_user_directory_process_users, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) - @defer.inlineCallbacks - def _populate_user_directory_createtables(self, progress, batch_size): + async def _populate_user_directory_createtables(self, progress, batch_size): # Get all the rooms that we want to process. def _make_staging_area(txn): @@ -85,7 +83,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,43 +98,45 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) - new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.db.runInteraction( + new_pos = await self.get_max_stream_id_in_current_state_deltas() + await self.db_pool.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + await self.db_pool.simple_insert( + TEMP_TABLE + "_position", {"position": new_pos} + ) - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_createtables" ) return 1 - @defer.inlineCallbacks - def _populate_user_directory_cleanup(self, progress, batch_size): + async def _populate_user_directory_cleanup(self, progress, batch_size): """ Update the user directory stream position, then clean up the old tables. """ - position = yield self.db.simple_select_one_onecol( + position = await self.db_pool.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) - yield self.update_user_directory_stream_pos(position) + await self.update_user_directory_stream_pos(position) def _delete_staging_area(txn): txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - yield self.db.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory_cleanup", _delete_staging_area ) - yield self.db.updates._end_background_update("populate_user_directory_cleanup") + await self.db_pool.updates._end_background_update( + "populate_user_directory_cleanup" + ) return 1 - @defer.inlineCallbacks - def _populate_user_directory_process_rooms(self, progress, batch_size): + async def _populate_user_directory_process_rooms(self, progress, batch_size): """ Args: progress (dict) @@ -147,7 +147,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # If we don't have progress filed, delete everything. if not progress: - yield self.delete_all_from_user_dir() + await self.delete_all_from_user_dir() def _get_next_batch(txn): # Only fetch 250 rooms, so we don't fetch too many at once, even @@ -172,13 +172,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return rooms_to_work_on - rooms_to_work_on = yield self.db.runInteraction( + rooms_to_work_on = await self.db_pool.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_rooms" ) return 1 @@ -191,19 +191,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): processed_event_count = 0 for room_id, event_count in rooms_to_work_on: - is_in_room = yield self.is_host_joined(room_id, self.server_name) + is_in_room = await self.is_host_joined(room_id, self.server_name) if is_in_room: - is_public = yield self.is_room_world_readable_or_publicly_joinable( + is_public = await self.is_room_world_readable_or_publicly_joinable( room_id ) - users_with_profile = yield state.get_current_users_in_room(room_id) + users_with_profile = await state.get_current_users_in_room(room_id) user_ids = set(users_with_profile) # Update each user in the user directory. for user_id, profile in users_with_profile.items(): - yield self.update_profile_in_user_dir( + await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -217,7 +217,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): to_insert.add(user_id) if to_insert: - yield self.add_users_in_public_rooms(room_id, to_insert) + await self.add_users_in_public_rooms(room_id, to_insert) to_insert.clear() else: for user_id in user_ids: @@ -237,22 +237,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # If it gets too big, stop and write to the database # to prevent storing too much in RAM. if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET: - yield self.add_users_who_share_private_room( + await self.add_users_who_share_private_room( room_id, to_insert ) to_insert.clear() if to_insert: - yield self.add_users_who_share_private_room(room_id, to_insert) + await self.add_users_who_share_private_room(room_id, to_insert) to_insert.clear() # We've finished a room. Delete it from the table. - yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + await self.db_pool.simple_delete_one( + TEMP_TABLE + "_rooms", {"room_id": room_id} + ) # Update the remaining counter. progress["remaining"] -= 1 - yield self.db.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory", - self.db.updates._background_update_progress_txn, + self.db_pool.updates._background_update_progress_txn, "populate_user_directory_process_rooms", progress, ) @@ -265,13 +267,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return processed_event_count - @defer.inlineCallbacks - def _populate_user_directory_process_users(self, progress, batch_size): + async def _populate_user_directory_process_users(self, progress, batch_size): """ If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_users" ) return 1 @@ -297,13 +298,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return users_to_work_on - users_to_work_on = yield self.db.runInteraction( + users_to_work_on = await self.db_pool.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more users -- complete the transaction. if not users_to_work_on: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_users" ) return 1 @@ -314,26 +315,27 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) for user_id in users_to_work_on: - profile = yield self.get_profileinfo(get_localpart_from_id(user_id)) - yield self.update_profile_in_user_dir( + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) # We've finished processing a user. Delete it from the table. - yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + await self.db_pool.simple_delete_one( + TEMP_TABLE + "_users", {"user_id": user_id} + ) # Update the remaining counter. progress["remaining"] -= 1 - yield self.db.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory", - self.db.updates._background_update_progress_txn, + self.db_pool.updates._background_update_progress_txn, "populate_user_directory_process_users", progress, ) return len(users_to_work_on) - @defer.inlineCallbacks - def is_room_world_readable_or_publicly_joinable(self, room_id): + async def is_room_world_readable_or_publicly_joinable(self, room_id): """Check if the room is either world_readable or publically joinable """ @@ -343,33 +345,40 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = yield self.get_filtered_current_state_ids( + current_state_ids = await self.get_filtered_current_state_ids( room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) if join_rules_id: - join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + join_rule_ev = await self.get_event(join_rules_id, allow_none=True) if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: - hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) if hist_vis_ev: if hist_vis_ev.content.get("history_visibility") == "world_readable": return True return False - def update_profile_in_user_dir(self, user_id, display_name, avatar_url): + async def update_profile_in_user_dir( + self, user_id: str, display_name: str, avatar_url: str + ) -> None: """ Update or add a user's profile in the user directory. """ + # If the display name or avatar URL are unexpected types, overwrite them. + if not isinstance(display_name, str): + display_name = None + if not isinstance(avatar_url, str): + avatar_url = None def _update_profile_in_user_dir_txn(txn): - new_entry = self.db.simple_upsert_txn( + new_entry = self.db_pool.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -443,7 +452,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self.db.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -456,21 +465,23 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db.runInteraction( + await self.db_pool.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) - def add_users_who_share_private_room(self, room_id, user_id_tuples): + async def add_users_who_share_private_room( + self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]] + ) -> None: """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. Args: - room_id (str) - user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + room_id + user_id_tuples: iterable of 2-tuple of user IDs. """ def _add_users_who_share_room_txn(txn): - self.db.simple_upsert_many_txn( + self.db_pool.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -482,22 +493,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): value_values=None, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) - def add_users_in_public_rooms(self, room_id, user_ids): + async def add_users_in_public_rooms( + self, room_id: str, user_ids: Iterable[str] + ) -> None: """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. Args: - room_id (str) - user_ids (list[str]) + room_id + user_ids """ def _add_users_in_public_rooms_txn(txn): - self.db.simple_upsert_many_txn( + self.db_pool.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -506,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): value_values=None, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) - def delete_all_from_user_dir(self): + async def delete_all_from_user_dir(self) -> None: """Delete the entire user directory """ @@ -521,13 +534,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn.execute("DELETE FROM users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) - return self.db.runInteraction( + await self.db_pool.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @cached() - def get_user_in_directory(self, user_id): - return self.db.simple_select_one( + async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -535,8 +548,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): desc="get_user_in_directory", ) - def update_user_directory_stream_pos(self, stream_id): - return self.db.simple_update_one( + async def update_user_directory_stream_pos(self, stream_id: str) -> None: + await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, @@ -550,47 +563,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(UserDirectoryStore, self).__init__(database, db_conn, hs) - def remove_from_user_dir(self, user_id): + async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id} ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, ) txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) + await self.db_pool.runInteraction( + "remove_from_user_dir", _remove_from_user_dir_txn + ) - @defer.inlineCallbacks - def get_users_in_dir_due_to_room(self, room_id): + async def get_users_in_dir_due_to_room(self, room_id): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self.db.simple_select_onecol( + user_ids_share_pub = await self.db_pool.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self.db.simple_select_onecol( + user_ids_share_priv = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -602,39 +616,38 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return user_ids - def remove_user_who_share_room(self, user_id, room_id): + async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None: """ Deletes entries in the users_who_share_*_rooms table. The first user should be a local user. Args: - user_id (str) - room_id (str) + user_id + room_id """ def _remove_user_who_share_room_txn(txn): - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.db.runInteraction( + await self.db_pool.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) - @defer.inlineCallbacks - def get_user_dir_rooms_user_is_in(self, user_id): + async def get_user_dir_rooms_user_is_in(self, user_id): """ Returns the rooms that a user is in. @@ -644,14 +657,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns: list: user_id """ - rows = yield self.db.simple_select_onecol( + rows = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self.db.simple_select_onecol( + pub_rows = await self.db_pool.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -662,42 +675,57 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - @defer.inlineCallbacks - def get_rooms_in_common_for_users(self, user_id, other_user_id): - """Given two user_ids find out the list of rooms they share. + @cached() + async def get_shared_rooms_for_users( + self, user_id: str, other_user_id: str + ) -> Set[str]: """ - sql = """ - SELECT room_id FROM ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) AS f1 INNER JOIN ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) f2 USING (room_id) + Returns the rooms that a local user shares with another local or remote user. + + Args: + user_id: The MXID of a local user + other_user_id: The MXID of the other user + + Returns: + A set of room ID's that the users share. """ - rows = yield self.db.execute( - "get_rooms_in_common_for_users", None, sql, user_id, other_user_id + def _get_shared_rooms_for_users_txn(txn): + txn.execute( + """ + SELECT p1.room_id + FROM users_in_public_rooms as p1 + INNER JOIN users_in_public_rooms as p2 + ON p1.room_id = p2.room_id + AND p1.user_id = ? + AND p2.user_id = ? + UNION + SELECT room_id + FROM users_who_share_private_rooms + WHERE + user_id = ? + AND other_user_id = ? + """, + (user_id, other_user_id, user_id, other_user_id), + ) + rows = self.db_pool.cursor_to_dict(txn) + return rows + + rows = await self.db_pool.runInteraction( + "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn ) - return [room_id for room_id, in rows] + return {row["room_id"] for row in rows} - def get_user_directory_stream_pos(self): - return self.db.simple_select_one_onecol( + async def get_user_directory_stream_pos(self) -> int: + return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", desc="get_user_directory_stream_pos", ) - @defer.inlineCallbacks - def search_user_dir(self, user_id, search_term, limit): + async def search_user_dir(self, user_id, search_term, limit): """Searches for users in directory Returns: @@ -794,8 +822,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self.db.execute( - "search_user_dir", self.db.cursor_to_dict, sql, *args + results = await self.db_pool.execute( + "search_user_dir", self.db_pool.cursor_to_dict, sql, *args ) limited = len(results) > limit diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index ec6b8a4ffd..2f7c95fc74 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -13,35 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -import operator - from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList class UserErasureWorkerStore(SQLBaseStore): @cached() - def is_user_erased(self, user_id): + async def is_user_erased(self, user_id: str) -> bool: """ Check if the given user id has requested erasure Args: - user_id (str): full user id to check + user_id: full user id to check Returns: - Deferred[bool]: True if the user has requested erasure + True if the user has requested erasure """ - return self.db.simple_select_onecol( + result = await self.db_pool.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", desc="is_user_erased", - ).addCallback(operator.truth) + ) + return bool(result) - @cachedList( - cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True - ) - def are_users_erased(self, user_ids): + @cachedList(cached_method_name="is_user_erased", list_name="user_ids") + async def are_users_erased(self, user_ids): """ Checks which users in a list have requested erasure @@ -49,14 +46,14 @@ class UserErasureWorkerStore(SQLBaseStore): user_ids (iterable[str]): full user id to check Returns: - Deferred[dict[str, bool]]: + dict[str, bool]: for each user, whether the user has requested erasure. """ # this serves the dual purpose of (a) making sure we can do len and # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, @@ -65,16 +62,15 @@ class UserErasureWorkerStore(SQLBaseStore): ) erased_users = {row["user_id"] for row in rows} - res = {u: u in erased_users for u in user_ids} - return res + return {u: u in erased_users for u in user_ids} class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id): + async def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: - user_id (str): full user_id to be erased + user_id: full user_id to be erased """ def f(txn): @@ -88,4 +84,26 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db.runInteraction("mark_user_erased", f) + await self.db_pool.runInteraction("mark_user_erased", f) + + async def mark_user_not_erased(self, user_id: str) -> None: + """Indicate that user_id is no longer erased. + + Args: + user_id: full user_id to be un-erased + """ + + def f(txn): + # first check if they are already in the list + txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,)) + if not txn.fetchone(): + return + + # They are there, delete them. + self.simple_delete_one_txn( + txn, "erased_users", keyvalues={"user_id": user_id} + ) + + self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) + + await self.db_pool.runInteraction("mark_user_not_erased", f) diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/databases/state/__init__.py index 86e09f6229..c90d022899 100644 --- a/synapse/storage/data_stores/state/__init__.py +++ b/synapse/storage/databases/state/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401 +from synapse.storage.databases.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index ff000bc9ec..139085b672 100644 --- a/synapse/storage/data_stores/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -15,12 +15,8 @@ import logging -from six import iteritems - -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter @@ -64,7 +60,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): count = 0 while next_group: - next_group = self.db.simple_select_one_onecol_txn( + next_group = self.db_pool.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -167,7 +163,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): ): break - next_group = self.db.simple_select_one_onecol_txn( + next_group = self.db_pool.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -184,24 +180,23 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, ) - self.db.updates.register_background_update_handler( + self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state ) - self.db.updates.register_background_index_update( + self.db_pool.updates.register_background_index_update( self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, index_name="state_groups_room_id_idx", table="state_groups", columns=["room_id"], ) - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): + async def _background_deduplicate_state(self, progress, batch_size): """This background update will slowly deduplicate state by reencoding them as deltas. """ @@ -214,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self.db.execute( + rows = await self.db_pool.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -280,17 +275,17 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): delta_state = { key: value - for key, value in iteritems(curr_state) + for key, value in curr_state.items() if prev_state.get(key, None) != value } - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, ) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="state_group_edges", values={ @@ -299,13 +294,13 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): }, ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -316,7 +311,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in iteritems(delta_state) + for key, state_id in delta_state.items() ], ) @@ -326,25 +321,24 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): "max_group": max_group, } - self.db.updates._background_update_progress_txn( + self.db_pool.updates._background_update_progress_txn( txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress ) return False, batch_size - finished, result = yield self.db.runInteraction( + finished, result = await self.db_pool.runInteraction( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) if finished: - yield self.db.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) return result * BATCH_SIZE_SCALE_FACTOR - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): + async def _background_index_state(self, progress, batch_size): def reindex_txn(conn): conn.rollback() if isinstance(self.database_engine, PostgresEngine): @@ -367,8 +361,10 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): ) txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - yield self.db.runWithConnection(reindex_txn) + await self.db_pool.runWithConnection(reindex_txn) - yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + await self.db_pool.updates._end_background_update( + self.STATE_GROUP_INDEX_UPDATE_NAME + ) return 1 diff --git a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql index ae09fa0065..ae09fa0065 100644 --- a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql +++ b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql diff --git a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql b/synapse/storage/databases/state/schema/delta/30/state_stream.sql index e85699e82e..e85699e82e 100644 --- a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql +++ b/synapse/storage/databases/state/schema/delta/30/state_stream.sql diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql index 1450313bfa..1450313bfa 100644 --- a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql +++ b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql diff --git a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql index 33980d02f0..33980d02f0 100644 --- a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql +++ b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql diff --git a/synapse/storage/data_stores/state/schema/delta/35/state.sql b/synapse/storage/databases/state/schema/delta/35/state.sql index 0f1fa68a89..0f1fa68a89 100644 --- a/synapse/storage/data_stores/state/schema/delta/35/state.sql +++ b/synapse/storage/databases/state/schema/delta/35/state.sql diff --git a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql index 97e5067ef4..97e5067ef4 100644 --- a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql +++ b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql diff --git a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py index 9fd1ccf6f7..9fd1ccf6f7 100644 --- a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py +++ b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql index 7916ef18b2..7916ef18b2 100644 --- a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql +++ b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/databases/state/schema/full_schemas/54/full.sql index 35f97d6b3d..35f97d6b3d 100644 --- a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql +++ b/synapse/storage/databases/state/schema/full_schemas/54/full.sql diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres index fcd926c9fb..fcd926c9fb 100644 --- a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres +++ b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/databases/state/store.py index f3ad1e4369..e924f1ca3b 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -17,16 +17,13 @@ import logging from collections import namedtuple from typing import Dict, Iterable, List, Set, Tuple -from six import iteritems -from six.moves import range - -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.state import StateFilter +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -54,7 +51,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """A data store for fetching/storing state groups. """ - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(StateGroupDataStore, self).__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the @@ -95,8 +92,16 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "*stateGroupMembersCache*", 500000, ) + def get_max_state_group_txn(txn: Cursor): + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + return txn.fetchone()[0] + + self._state_group_seq_gen = build_sequence_generator( + self.database_engine, get_max_state_group_txn, "state_group_id_seq" + ) + @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): + async def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between the old and the new. @@ -105,7 +110,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """ def _get_state_group_delta_txn(txn): - prev_group = self.db.simple_select_one_onecol_txn( + prev_group = self.db_pool.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, @@ -116,7 +121,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self.db.simple_select_list_txn( + delta_ids = self.db_pool.simple_select_list_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, @@ -128,14 +133,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, ) - return self.db.runInteraction( + return await self.db_pool.runInteraction( "get_state_group_delta", _get_state_group_delta_txn ) - @defer.inlineCallbacks - def _get_state_groups_from_groups( + async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups from the database, filtering on types of state events. @@ -144,13 +148,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ results = {} chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = yield self.db.runInteraction( + res = await self.db_pool.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -199,10 +203,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types - @defer.inlineCallbacks - def _get_state_for_groups( + async def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[int, StateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -212,7 +215,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -221,14 +224,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ( non_member_state, incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( + ) = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( + (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) @@ -249,7 +249,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # Help the cache hit ratio by expanding the filter a bit db_state_filter = state_filter.return_expanded() - group_to_state_dict = yield self._get_state_groups_from_groups( + group_to_state_dict = await self._get_state_groups_from_groups( list(incomplete_groups), state_filter=db_state_filter ) @@ -263,7 +263,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # And finally update the result dict, by filtering out any extra # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): + for group, group_state_dict in group_to_state_dict.items(): # We just replace any existing entries, as we will have loaded # everything we need from the database anyway. state[group] = state_filter.filter_state(group_state_dict) @@ -341,11 +341,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): else: non_member_types = non_member_filter.concrete_types() - for group, group_state_dict in iteritems(group_to_state_dict): + for group, group_state_dict in group_to_state_dict.items(): state_dict_members = {} state_dict_non_members = {} - for k, v in iteritems(group_state_dict): + for k, v in group_state_dict.items(): if k[0] == EventTypes.Member: state_dict_members[k] = v else: @@ -365,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) - def store_state_group( + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): + ) -> int: """Store a new set of state, returning a newly assigned state group. Args: @@ -381,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): to event_id. Returns: - Deferred[int]: The state group ID + The state group ID """ def _store_state_group_txn(txn): @@ -389,9 +389,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # AFAIK, this can never happen raise Exception("current_state_ids cannot be None") - state_group = self.database_engine.get_next_state_group_id(txn) + state_group = self._state_group_seq_gen.get_next_id_txn(txn) - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="state_groups", values={"id": state_group, "room_id": room_id, "event_id": event_id}, @@ -400,7 +400,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if prev_group: - is_in_db = self.db.simple_select_one_onecol_txn( + is_in_db = self.db_pool.simple_select_one_onecol_txn( txn, table="state_groups", keyvalues={"id": prev_group}, @@ -415,13 +415,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.db.simple_insert_txn( + self.db_pool.simple_insert_txn( txn, table="state_group_edges", values={"state_group": state_group, "prev_state_group": prev_group}, ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -432,11 +432,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in iteritems(delta_ids) + for key, state_id in delta_ids.items() ], ) else: - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -447,7 +447,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in iteritems(current_state_ids) + for key, state_id in current_state_ids.items() ], ) @@ -458,7 +458,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): current_member_state_ids = { s: ev - for (s, ev) in iteritems(current_state_ids) + for (s, ev) in current_state_ids.items() if s[0] == EventTypes.Member } txn.call_after( @@ -470,7 +470,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): current_non_member_state_ids = { s: ev - for (s, ev) in iteritems(current_state_ids) + for (s, ev) in current_state_ids.items() if s[0] != EventTypes.Member } txn.call_after( @@ -482,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_group - return self.db.runInteraction("store_state_group", _store_state_group_txn) + return await self.db_pool.runInteraction( + "store_state_group", _store_state_group_txn + ) - def purge_unreferenced_state_groups( + async def purge_unreferenced_state_groups( self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: + ) -> None: """Deletes no longer referenced state groups and de-deltas any state groups that reference them. @@ -497,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): to delete. """ - return self.db.runInteraction( + await self.db_pool.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, @@ -509,7 +511,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self.db.simple_select_many_txn( + rows = self.db_pool.simple_select_many_txn( txn, table="state_group_edges", column="prev_state_group", @@ -536,15 +538,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} ) - self.db.simple_delete_txn( + self.db_pool.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": sg} ) - self.db.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -555,7 +557,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in iteritems(curr_state) + for key, state_id in curr_state.items() ], ) @@ -569,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ((sg,) for sg in state_groups_to_delete), ) - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): + async def get_previous_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: """Fetch the previous groups of the given state groups. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. + A mapping from state group to previous state group. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, @@ -592,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return {row["state_group"]: row["prev_state_group"] for row in rows} - def purge_room_state(self, room_id, state_groups_to_delete): + async def purge_room_state(self, room_id, state_groups_to_delete): """Deletes all record of a room from state tables Args: @@ -600,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_groups_to_delete (list[int]): State groups to delete """ - return self.db.runInteraction( + await self.db_pool.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, @@ -611,7 +613,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="state_groups_state", column="state_group", @@ -622,7 +624,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # ... and the state group edges logger.info("[purge] removing %s from state_group_edges", room_id) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="state_group_edges", column="state_group", @@ -633,7 +635,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - self.db.simple_delete_many_txn( + self.db_pool.simple_delete_many_txn( txn, table="state_groups", column="id", diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index ab0bbe4bd3..908cbc79e3 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): def lock_table(self, txn, table: str) -> None: ... - @abc.abstractmethod - def get_next_state_group_id(self, txn) -> int: - """Returns an int that can be used as a new state_group ID - """ - ... - @property @abc.abstractmethod def server_version(self) -> str: diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 6c7d08a6f2..ff39281f85 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -92,7 +92,7 @@ class PostgresEngine(BaseDatabaseEngine): errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,)) if ctype != "C": - errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,)) + errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,)) if errors: raise IncorrectDatabaseSetup( @@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine): def lock_table(self, txn, table): txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - txn.execute("SELECT nextval('state_group_id_seq')") - return txn.fetchone()[0] - @property def server_version(self): """Returns a string giving the server version. For example: '8.1.5' diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 215a949442..8a0f8c89d1 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): def lock_table(self, txn, table): return - def get_next_state_group_id(self, txn): - """Returns an int that can be used as a new state_group ID - """ - # We do application locking here since if we're using sqlite then - # we are a single process synapse. - with self._current_state_group_id_lock: - if self._current_state_group_id is None: - txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") - self._current_state_group_id = txn.fetchone()[0] - - self._current_state_group_id += 1 - return self._current_state_group_id - @property def server_version(self): """Gets a string giving the server version. For example: '3.22.0' diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 4769b21529..afd10f7bae 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -22,6 +22,6 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True) -class FetchKeyResult(object): +class FetchKeyResult: verify_key = attr.ib() # VerifyKey: the key itself valid_until_ts = attr.ib() # int: how long we can use this key for diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index f159400a87..dbaeef91dd 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -20,21 +20,17 @@ import logging from collections import deque, namedtuple from typing import Iterable, List, Optional, Set, Tuple -from six import iteritems -from six.moves import range - from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import FrozenEvent +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore -from synapse.storage.data_stores import DataStores -from synapse.storage.data_stores.main.events import DeltaState +from synapse.storage.databases import Databases +from synapse.storage.databases.main.events import DeltaState from synapse.types import StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -73,7 +69,7 @@ stale_forward_extremities_counter = Histogram( ) -class _EventPeristenceQueue(object): +class _EventPeristenceQueue: """Queues up events so that they can be persisted in bulk with only one concurrent transaction per room. """ @@ -176,14 +172,14 @@ class _EventPeristenceQueue(object): pass -class EventsPersistenceStorage(object): +class EventsPersistenceStorage: """High level interface for handling persisting newly received events. Takes care of batching up events by room, and calculating the necessary current state and forward extremity changes. """ - def __init__(self, hs, stores: DataStores): + def __init__(self, hs, stores: Databases): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same # store for now. @@ -196,12 +192,11 @@ class EventsPersistenceStorage(object): self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def persist_events( + async def persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ): + ) -> int: """ Write events to the database Args: @@ -211,14 +206,14 @@ class EventsPersistenceStorage(object): which might update the current state etc. Returns: - Deferred[int]: the stream ordering of the latest persisted event + the stream ordering of the latest persisted event """ partitioned = {} for event, ctx in events_and_contexts: partitioned.setdefault(event.room_id, []).append((event, ctx)) deferreds = [] - for room_id, evs_ctxs in iteritems(partitioned): + for room_id, evs_ctxs in partitioned.items(): d = self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled ) @@ -227,22 +222,19 @@ class EventsPersistenceStorage(object): for room_id in partitioned: self._maybe_start_persisting(room_id) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) - max_persisted_id = yield self.main_store.get_current_events_token() + return self.main_store.get_current_events_token() - return max_persisted_id - - @defer.inlineCallbacks - def persist_event( - self, event: FrozenEvent, context: EventContext, backfilled: bool = False - ): + async def persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ) -> Tuple[int, int]: """ Returns: - Deferred[Tuple[int, int]]: the stream ordering of ``event``, - and the stream ordering of the latest persisted event + The stream ordering of `event`, and the stream ordering of the + latest persisted event """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled @@ -250,9 +242,9 @@ class EventsPersistenceStorage(object): self._maybe_start_persisting(event.room_id) - yield make_deferred_yieldable(deferred) + await make_deferred_yieldable(deferred) - max_persisted_id = yield self.main_store.get_current_events_token() + max_persisted_id = self.main_store.get_current_events_token() return (event.internal_metadata.stream_ordering, max_persisted_id) def _maybe_start_persisting(self, room_id: str): @@ -266,7 +258,7 @@ class EventsPersistenceStorage(object): async def _persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, ): """Calculates the change to current state and forward extremities, and @@ -319,7 +311,7 @@ class EventsPersistenceStorage(object): (event, context) ) - for room_id, ev_ctx_rm in iteritems(events_by_room): + for room_id, ev_ctx_rm in events_by_room.items(): latest_event_ids = await self.main_store.get_latest_event_ids_in_room( room_id ) @@ -443,7 +435,7 @@ class EventsPersistenceStorage(object): async def _calculate_new_extremities( self, room_id: str, - event_contexts: List[Tuple[FrozenEvent, EventContext]], + event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: List[str], ): """Calculates the new forward extremities for a room given events to @@ -501,7 +493,7 @@ class EventsPersistenceStorage(object): async def _get_new_state_after_events( self, room_id: str, - events_context: List[Tuple[FrozenEvent, EventContext]], + events_context: List[Tuple[EventBase, EventContext]], old_latest_event_ids: Iterable[str], new_latest_event_ids: Iterable[str], ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: @@ -651,6 +643,10 @@ class EventsPersistenceStorage(object): room_version = await self.main_store.get_room_version_id(room_id) logger.debug("calling resolve_state_groups from preserve_events") + + # Avoid a circular import. + from synapse.state import StateResolutionStore + res = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, @@ -674,7 +670,7 @@ class EventsPersistenceStorage(object): to_insert = { key: ev_id - for key, ev_id in iteritems(current_state) + for key, ev_id in current_state.items() if ev_id != existing_state.get(key) } @@ -683,7 +679,7 @@ class EventsPersistenceStorage(object): async def _is_server_still_joined( self, room_id: str, - ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], + ev_ctx_rm: List[Tuple[EventBase, EventContext]], delta: DeltaState, current_state: Optional[StateMap[str]], potentially_left_users: Set[str], @@ -786,9 +782,3 @@ class EventsPersistenceStorage(object): for user_id in left_users: await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id) - - async def locally_reject_invite(self, user_id: str, room_id: str) -> int: - """Mark the invite has having been rejected even though we failed to - create a leave event for it. - """ - return await self.persist_events_store.locally_reject_invite(user_id, room_id) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 9cc3b51fe6..ee60e2a718 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -47,8 +47,24 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]): - """Prepares a database for usage. Will either create all necessary tables +OUTDATED_SCHEMA_ON_WORKER_ERROR = ( + "Expected database schema version %i but got %i: run the main synapse process to " + "upgrade the database schema before starting worker processes." +) + +EMPTY_DATABASE_ON_WORKER_ERROR = ( + "Uninitialised database: run the main synapse process to prepare the database " + "schema before starting worker processes." +) + +UNAPPLIED_DELTA_ON_WORKER_ERROR = ( + "Database schema delta %s has not been applied: run the main synapse process to " + "upgrade the database schema before starting worker processes." +) + + +def prepare_database(db_conn, database_engine, config, databases=["main", "state"]): + """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. If `config` is None then prepare_database will assert that no upgrade is @@ -60,37 +76,57 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already - data_stores (list[str]): The name of the data stores that will be used - with this database. Defaults to all data stores. + databases (list[str]): The name of the databases that will be used + with this physical database. Defaults to all databases. """ try: cur = db_conn.cursor() + + logger.info("%r: Checking existing schema version", databases) version_info = _get_or_create_schema_state(cur, database_engine) if version_info: user_version, delta_files, upgraded = version_info + logger.info( + "%r: Existing schema is %i (+%i deltas)", + databases, + user_version, + len(delta_files), + ) + # config should only be None when we are preparing an in-memory SQLite db, + # which should be empty. if config is None: - if user_version != SCHEMA_VERSION: - # If we don't pass in a config file then we are expecting to - # have already upgraded the DB. - raise UpgradeDatabaseException( - "Expected database schema version %i but got %i" - % (SCHEMA_VERSION, user_version) - ) - else: - _upgrade_existing_database( - cur, - user_version, - delta_files, - upgraded, - database_engine, - config, - data_stores=data_stores, + raise ValueError( + "config==None in prepare_database, but databse is not empty" + ) + + # if it's a worker app, refuse to upgrade the database, to avoid multiple + # workers doing it at once. + if config.worker_app is not None and user_version != SCHEMA_VERSION: + raise UpgradeDatabaseException( + OUTDATED_SCHEMA_ON_WORKER_ERROR % (SCHEMA_VERSION, user_version) ) + + _upgrade_existing_database( + cur, + user_version, + delta_files, + upgraded, + database_engine, + config, + databases=databases, + ) else: - _setup_new_database(cur, database_engine, data_stores=data_stores) + logger.info("%r: Initialising new database", databases) + + # if it's a worker app, refuse to upgrade the database, to avoid multiple + # workers doing it at once. + if config and config.worker_app is not None: + raise UpgradeDatabaseException(EMPTY_DATABASE_ON_WORKER_ERROR) + + _setup_new_database(cur, database_engine, databases=databases) # check if any of our configured dynamic modules want a database if config is not None: @@ -103,9 +139,9 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta raise -def _setup_new_database(cur, database_engine, data_stores): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas, including schemas from the given data +def _setup_new_database(cur, database_engine, databases): + """Sets up the physical database by finding a base set of "full schemas" and + then applying any necessary deltas, including schemas from the given data stores. The "full_schemas" directory has subdirectories named after versions. This @@ -138,8 +174,8 @@ def _setup_new_database(cur, database_engine, data_stores): Args: cur (Cursor): a database cursor database_engine (DatabaseEngine) - data_stores (list[str]): The names of the data stores to instantiate - on the given database. + databases (list[str]): The names of the databases to instantiate + on the given physical database. """ # We're about to set up a brand new database so we check that its @@ -176,13 +212,13 @@ def _setup_new_database(cur, database_engine, data_stores): directories.extend( os.path.join( dir_path, - "data_stores", - data_store, + "databases", + database, "schema", "full_schemas", str(max_current_ver), ) - for data_store in data_stores + for database in databases ) directory_entries = [] @@ -219,7 +255,7 @@ def _setup_new_database(cur, database_engine, data_stores): upgraded=False, database_engine=database_engine, config=None, - data_stores=data_stores, + databases=databases, is_empty=True, ) @@ -231,10 +267,10 @@ def _upgrade_existing_database( upgraded, database_engine, config, - data_stores, + databases, is_empty=False, ): - """Upgrades an existing database. + """Upgrades an existing physical database. Delta files can either be SQL stored in *.sql files, or python modules in *.py. @@ -285,8 +321,8 @@ def _upgrade_existing_database( config (synapse.config.homeserver.HomeServerConfig|None): None if we are initialising a blank database, otherwise the application config - data_stores (list[str]): The names of the data stores to instantiate - on the given database. + databases (list[str]): The names of the databases to instantiate + on the given physical database. is_empty (bool): Is this a blank database? I.e. do we need to run the upgrade portions of the delta scripts. """ @@ -295,6 +331,8 @@ def _upgrade_existing_database( else: assert config + is_worker = config and config.worker_app is not None + if current_version > SCHEMA_VERSION: raise ValueError( "Cannot use this database as it is too " @@ -303,8 +341,8 @@ def _upgrade_existing_database( # some of the deltas assume that config.server_name is set correctly, so now # is a good time to run the sanity check. - if not is_empty and "main" in data_stores: - from synapse.storage.data_stores.main import check_database_before_upgrade + if not is_empty and "main" in databases: + from synapse.storage.databases.main import check_database_before_upgrade check_database_before_upgrade(cur, database_engine, config) @@ -322,7 +360,7 @@ def _upgrade_existing_database( specific_engine_extensions = (".sqlite", ".postgres") for v in range(start_ver, SCHEMA_VERSION + 1): - logger.info("Upgrading schema to v%d", v) + logger.info("Applying schema deltas for v%d", v) # We need to search both the global and per data store schema # directories for schema updates. @@ -330,11 +368,9 @@ def _upgrade_existing_database( # First we find the directories to search in delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) directories = [delta_dir] - for data_store in data_stores: + for database in databases: directories.append( - os.path.join( - dir_path, "data_stores", data_store, "schema", "delta", str(v) - ) + os.path.join(dir_path, "databases", database, "schema", "delta", str(v)) ) # Used to check if we have any duplicate file names @@ -384,9 +420,15 @@ def _upgrade_existing_database( continue root_name, ext = os.path.splitext(file_name) + if ext == ".py": # This is a python upgrade module. We need to import into some # package and then execute its `run_upgrade` function. + if is_worker: + raise PrepareDatabaseException( + UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path + ) + module_name = "synapse.storage.v%d_%s" % (v, root_name) with open(absolute_path) as python_file: module = imp.load_source(module_name, absolute_path, python_file) @@ -401,10 +443,18 @@ def _upgrade_existing_database( continue elif ext == ".sql": # A plain old .sql file, just read and execute it + if is_worker: + raise PrepareDatabaseException( + UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path + ) logger.info("Applying schema %s", relative_path) executescript(cur, absolute_path) elif ext == specific_engine_extension and root_name.endswith(".sql"): # A .sql file specific to our engine; just read and execute it + if is_worker: + raise PrepareDatabaseException( + UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path + ) logger.info("Applying engine-specific schema %s", relative_path) executescript(cur, absolute_path) elif ext in specific_engine_extensions and root_name.endswith(".sql"): @@ -434,6 +484,8 @@ def _upgrade_existing_database( (v, True), ) + logger.info("Schema now up to date") + def _apply_module_schemas(txn, database_engine, config): """Apply the module schemas for the dynamic modules, if any @@ -571,7 +623,7 @@ def _get_or_create_schema_state(txn, database_engine): @attr.s() -class _DirectoryListing(object): +class _DirectoryListing: """Helper class to store schema file name and the absolute path to it. diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index fdc0abf5cf..bfa0a9fd06 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -15,62 +15,60 @@ import itertools import logging - -from twisted.internet import defer +from typing import Set logger = logging.getLogger(__name__) -class PurgeEventsStorage(object): +class PurgeEventsStorage: """High level interface for purging rooms and event history. """ def __init__(self, hs, stores): self.stores = stores - @defer.inlineCallbacks - def purge_room(self, room_id: str): + async def purge_room(self, room_id: str): """Deletes all record of a room """ - state_groups_to_delete = yield self.stores.main.purge_room(room_id) - yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) + state_groups_to_delete = await self.stores.main.purge_room(room_id) + await self.stores.state.purge_room_state(room_id, state_groups_to_delete) - @defer.inlineCallbacks - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> None: """Deletes room history before a certain point Args: - room_id (str): + room_id: The room ID - token (str): A topological token to delete events before + token: A topological token to delete events before - delete_local_events (bool): + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). """ - state_groups = yield self.stores.main.purge_history( + state_groups = await self.stores.main.purge_history( room_id, token, delete_local_events ) logger.info("[purge] finding state groups that can be deleted") - sg_to_delete = yield self._find_unreferenced_groups(state_groups) + sg_to_delete = await self._find_unreferenced_groups(state_groups) - yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) + await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) - @defer.inlineCallbacks - def _find_unreferenced_groups(self, state_groups): + async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: """Used when purging history to figure out which state groups can be deleted. Args: - state_groups (set[int]): Set of state groups referenced by events + state_groups: Set of state groups referenced by events that are going to be deleted. Returns: - Deferred[set[int]] The set of state groups that can be deleted. + The set of state groups that can be deleted. """ # Graph of state group -> previous group graph = {} @@ -93,7 +91,7 @@ class PurgeEventsStorage(object): current_search = set(itertools.islice(next_to_search, 100)) next_to_search -= current_search - referenced = yield self.stores.main.get_referenced_state_groups( + referenced = await self.stores.main.get_referenced_state_groups( current_search ) referenced_groups |= referenced @@ -102,7 +100,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.state.get_previous_state_groups(current_search) + edges = await self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index d471ec9860..d30e3f11e7 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) @attr.s -class PaginationChunk(object): +class PaginationChunk: """Returned by relation pagination APIs. Attributes: @@ -51,7 +51,7 @@ class PaginationChunk(object): @attr.s(frozen=True, slots=True) -class RelationPaginationToken(object): +class RelationPaginationToken: """Pagination token for relation pagination API. As the results are in topological order, we can use the @@ -82,7 +82,7 @@ class RelationPaginationToken(object): @attr.s(frozen=True, slots=True) -class AggregationPaginationToken(object): +class AggregationPaginationToken: """Pagination token for relation aggregation pagination API. As the results are order by count and then MAX(stream_ordering) of the diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c522c80922..8f68d968f0 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,15 +14,12 @@ # limitations under the License. import logging -from typing import Iterable, List, TypeVar - -from six import iteritems, itervalues +from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr -from twisted.internet import defer - from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -32,58 +29,57 @@ T = TypeVar("T") @attr.s(slots=True) -class StateFilter(object): +class StateFilter: """A filter used when querying for state. Attributes: - types (dict[str, set[str]|None]): Map from type to set of state keys (or - None). This specifies which state_keys for the given type to fetch - from the DB. If None then all events with that type are fetched. If - the set is empty then no events with that type are fetched. - include_others (bool): Whether to fetch events with types that do not + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not appear in `types`. """ - types = attr.ib() - include_others = attr.ib(default=False) + types = attr.ib(type=Dict[str, Optional[Set[str]]]) + include_others = attr.ib(default=False, type=bool) def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing # wildcards from the types dictionary if self.include_others: - self.types = {k: v for k, v in iteritems(self.types) if v is not None} + self.types = {k: v for k, v in self.types.items() if v is not None} @staticmethod - def all(): + def all() -> "StateFilter": """Creates a filter that fetches everything. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=True) @staticmethod - def none(): + def none() -> "StateFilter": """Creates a filter that fetches nothing. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=False) @staticmethod - def from_types(types): + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": """Creates a filter that only fetches the given types Args: - types (Iterable[tuple[str, str|None]]): A list of type and state - keys to fetch. A state_key of None fetches everything for - that type + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type Returns: - StateFilter + The new state filter. """ - type_dict = {} + type_dict = {} # type: Dict[str, Optional[Set[str]]] for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -93,24 +89,24 @@ class StateFilter(object): type_dict[typ] = None continue - type_dict.setdefault(typ, set()).add(s) + type_dict.setdefault(typ, set()).add(s) # type: ignore return StateFilter(types=type_dict) @staticmethod - def from_lazy_load_member_list(members): + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member events for the given users Args: - members (iterable[str]): Set of user IDs + members: Set of user IDs Returns: - StateFilter + The new state filter """ return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) - def return_expanded(self): + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the current one, i.e. anything that passes the current filter will pass @@ -132,7 +128,7 @@ class StateFilter(object): return all non-member events Returns: - StateFilter + The new state filter. """ if self.is_full(): @@ -150,7 +146,7 @@ class StateFilter(object): has_non_member_wildcard = self.include_others or any( state_keys is None - for t, state_keys in iteritems(self.types) + for t, state_keys in self.types.items() if t != EventTypes.Member ) @@ -169,7 +165,7 @@ class StateFilter(object): include_others=True, ) - def make_sql_filter_clause(self): + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. For example: @@ -181,13 +177,12 @@ class StateFilter(object): Returns: - tuple[str, list]: The SQL string (may be empty) and arguments. An - empty SQL string is returned when the filter matches everything - (i.e. is "full"). + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). """ where_clause = "" - where_args = [] + where_args = [] # type: List[str] if self.is_full(): return where_clause, where_args @@ -199,7 +194,7 @@ class StateFilter(object): # First we build up a lost of clauses for each type/state_key combo clauses = [] - for etype, state_keys in iteritems(self.types): + for etype, state_keys in self.types.items(): if state_keys is None: clauses.append("(type = ?)") where_args.append(etype) @@ -223,7 +218,7 @@ class StateFilter(object): return where_clause, where_args - def max_entries_returned(self): + def max_entries_returned(self) -> Optional[int]: """Returns the maximum number of entries this filter will return if known, otherwise returns None. @@ -251,7 +246,7 @@ class StateFilter(object): return dict(state_dict) filtered_state = {} - for k, v in iteritems(state_dict): + for k, v in state_dict.items(): typ, state_key = k if typ in self.types: state_keys = self.types[typ] @@ -262,42 +257,42 @@ class StateFilter(object): return filtered_state - def is_full(self): + def is_full(self) -> bool: """Whether this filter fetches everything or not Returns: - bool + True if the filter fetches everything. """ return self.include_others and not self.types - def has_wildcards(self): + def has_wildcards(self) -> bool: """Whether the filter includes wildcards or is attempting to fetch specific state. Returns: - bool + True if the filter includes wildcards. """ return self.include_others or any( - state_keys is None for state_keys in itervalues(self.types) + state_keys is None for state_keys in self.types.values() ) - def concrete_types(self): + def concrete_types(self) -> List[Tuple[str, str]]: """Returns a list of concrete type/state_keys (i.e. not None) that will be fetched. This will be a complete list if `has_wildcards` returns False, but otherwise will be a subset (or even empty). Returns: - list[tuple[str,str]] + A list of type/state_keys tuples. """ return [ (t, s) - for t, state_keys in iteritems(self.types) + for t, state_keys in self.types.items() if state_keys is not None for s in state_keys ] - def get_member_split(self): + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching against non member state. @@ -309,7 +304,7 @@ class StateFilter(object): state caches). Returns: - tuple[StateFilter, StateFilter]: The member and non member filters + The member and non member filters """ if EventTypes.Member in self.types: @@ -324,84 +319,91 @@ class StateFilter(object): member_filter = StateFilter.none() non_member_filter = StateFilter( - types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member}, + types={k: v for k, v in self.types.items() if k != EventTypes.Member}, include_others=self.include_others, ) return member_filter, non_member_filter -class StateGroupStorage(object): +class StateGroupStorage: """High level interface to fetching state for event. """ def __init__(self, hs, stores): self.stores = stores - def get_state_group_delta(self, state_group: int): + async def get_state_group_delta(self, state_group: int): """Given a state group try to return a previous group and a delta between the old and the new. + Args: + state_group: The state group used to retrieve state deltas. + Returns: - Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: + Tuple[Optional[int], Optional[StateMap[str]]]: (prev_group, delta_ids) """ - return self.stores.state.get_state_group_delta(state_group) + return await self.stores.state.get_state_group_delta(state_group) - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): + async def get_state_groups_ids( + self, _room_id: str, event_ids: Iterable[str] + ) -> Dict[int, StateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events + _room_id: id of the room for these events + event_ids: ids of the events Returns: - Deferred[dict[int, StateMap[str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: return {} - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) - groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.state._get_state_for_groups(groups) + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups(groups) return group_to_state - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): + async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: """Get the event IDs of all the state in the given state group Args: - state_group (int) + state_group: A state group for which we want to get the state IDs. Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + Resolves to a map of (type, state_key) -> event_id """ - group_to_state = yield self._get_state_for_groups((state_group,)) + group_to_state = await self._get_state_for_groups((state_group,)) return group_to_state[state_group] - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): + async def get_state_groups( + self, room_id: str, event_ids: Iterable[str] + ) -> Dict[int, List[EventBase]]: """ Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. + dict of state_group_id -> list of state events. """ if not event_ids: return {} - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ ev_id - for group_ids in itervalues(group_to_ids) - for ev_id in itervalues(group_ids) + for group_ids in group_to_ids.values() + for ev_id in group_ids.values() ], get_prev_content=False, ) @@ -409,15 +411,15 @@ class StateGroupStorage(object): return { group: [ state_event_map[v] - for v in itervalues(event_id_map) + for v in event_id_map.values() if v in state_event_map ] - for group, event_id_map in iteritems(group_to_ids) + for group, event_id_map in group_to_ids.items() } def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Returns the state groups for a given set of groups, filtering on types of state events. @@ -425,140 +427,148 @@ class StateGroupStorage(object): groups: list of state group IDs to query state_filter: The state filter used to fetch state from the database. + Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """Given a list of event_ids and type tuples, return a list of state dicts for each event. + Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. + Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) - groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.state._get_state_for_groups( + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) - state_event_map = yield self.stores.main.get_events( - [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], + state_event_map = await self.stores.main.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], get_prev_content=False, ) event_to_state = { event_id: { k: state_event_map[v] - for k, v in iteritems(group_to_state[group]) + for k, v in group_to_state[group].items() if v in state_event_map } - for event_id, group in iteritems(event_to_groups) + for event_id, group in event_to_groups.items() } return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_ids_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from event_id -> (type, state_key) -> event_id + A dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) - groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.state._get_state_for_groups( + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) event_to_state = { event_id: group_to_state[group] - for event_id, group in iteritems(event_to_groups) + for event_id, group in event_to_groups.items() } return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], state_filter) + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_ids_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) + state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. from the database. + Returns: - Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_for_groups(groups, state_filter) - def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): + async def store_state_group( + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[dict], + current_state_ids: dict, + ) -> int: """Store a new set of state, returning a newly assigned state group. Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) + current_state_ids: The state to store. Map of (type, state_key) to event_id. Returns: - Deferred[int]: The state group ID + The state group ID """ - return self.stores.state.store_state_group( + return await self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) diff --git a/synapse/storage/types.py b/synapse/storage/types.py index daff81c5ee..2d2b560e74 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,12 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, Iterable, Iterator, List, Tuple from typing_extensions import Protocol - """ Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f89ce0bed2..b7eb4f8ac9 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -14,16 +14,21 @@ # limitations under the License. import contextlib +import heapq +import logging import threading from collections import deque -from typing import Dict, Set, Tuple +from typing import Dict, List, Set from typing_extensions import Deque -from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.util.sequence import PostgresSequenceGenerator +logger = logging.getLogger(__name__) -class IdGenerator(object): + +class IdGenerator: def __init__(self, db_conn, table, column): self._lock = threading.Lock() self._next_id = _load_current_id(db_conn, table, column) @@ -46,6 +51,8 @@ def _load_current_id(db_conn, table, column, step=1): Returns: int """ + # debug logging for https://github.com/matrix-org/synapse/issues/7968 + logger.info("initialising stream generator for %s(%s)", table, column) cur = db_conn.cursor() if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) @@ -57,7 +64,7 @@ def _load_current_id(db_conn, table, column, step=1): return (max if step > 0 else min)(current_id, step) -class StreamIdGenerator(object): +class StreamIdGenerator: """Used to generate new stream ids when persisting events while keeping track of which transactions have been completed. @@ -79,7 +86,7 @@ class StreamIdGenerator(object): upwards, -1 to grow downwards. Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -94,10 +101,10 @@ class StreamIdGenerator(object): ) self._unfinished_ids = deque() # type: Deque[int] - def get_next(self): + async def get_next(self): """ Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -116,10 +123,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, n): + async def get_next_mult(self, n): """ Usage: - with stream_id_gen.get_next(n) as stream_ids: + with await stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -157,63 +164,13 @@ class StreamIdGenerator(object): return self._current + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. -class ChainedIdGenerator(object): - """Used to generate new stream ids where the stream must be kept in sync - with another stream. It generates pairs of IDs, the first element is an - integer ID for this stream, the second element is the ID for the stream - that this stream needs to be kept in sync with.""" - - def __init__(self, chained_generator, db_conn, table, column): - self.chained_generator = chained_generator - self._table = table - self._lock = threading.Lock() - self._current_max = _load_current_id(db_conn, table, column) - self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] - - def get_next(self): - """ - Usage: - with stream_id_gen.get_next() as (stream_id, chained_id): - # ... persist event ... - """ - with self._lock: - self._current_max += 1 - next_id = self._current_max - chained_id = self.chained_generator.get_current_token() - - self._unfinished_ids.append((next_id, chained_id)) - - @contextlib.contextmanager - def manager(): - try: - yield (next_id, chained_id) - finally: - with self._lock: - self._unfinished_ids.remove((next_id, chained_id)) - - return manager() - - def get_current_token(self): - """Returns the maximum stream id such that all stream ids less than or - equal to it have been successfully persisted. - """ - with self._lock: - if self._unfinished_ids: - stream_id, chained_id = self._unfinished_ids[0] - return stream_id - 1, chained_id - - return self._current_max, self.chained_generator.get_current_token() - - def advance(self, token: int): - """Stub implementation for advancing the token when receiving updates - over replication; raises an exception as this instance should be the - only source of updates. + For streams with single writers this is equivalent to + `get_current_token`. """ - - raise Exception( - "Attempted to advance token on source for table %r", self._table - ) + return self.get_current_token() class MultiWriterIdGenerator: @@ -233,25 +190,32 @@ class MultiWriterIdGenerator: id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + positive: Whether the IDs are positive (true) or negative (false). + When using negative IDs we go backwards from -1 to -2, -3, etc. """ def __init__( self, db_conn, - db: Database, + db: DatabasePool, instance_name: str, table: str, instance_column: str, id_column: str, sequence_name: str, + positive: bool = True, ): self._db = db self._instance_name = instance_name - self._sequence_name = sequence_name + self._positive = positive + self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. self._lock = threading.Lock() + # Note: If we are a negative stream then we still store all the IDs as + # positive to make life easier for us, and simply negate the IDs when we + # return them. self._current_positions = self._load_current_ids( db_conn, table, instance_column, id_column ) @@ -260,16 +224,38 @@ class MultiWriterIdGenerator: # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + # We track the max position where we know everything before has been + # persisted. This is done by a) looking at the min across all instances + # and b) noting that if we have seen a run of persisted positions + # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7). + # + # Note: There is no guarentee that the IDs generated by the sequence + # will be gapless; gaps can form when e.g. a transaction was rolled + # back. This means that sometimes we won't be able to skip forward the + # position even though everything has been persisted. However, since + # gaps should be relatively rare it's still worth doing the book keeping + # that allows us to skip forwards when there are gapless runs of + # positions. + self._persisted_upto_position = ( + min(self._current_positions.values()) if self._current_positions else 0 + ) + self._known_persisted_positions = [] # type: List[int] + + self._sequence_gen = PostgresSequenceGenerator(sequence_name) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ) -> Dict[str, int]: + # If positive stream aggregate via MAX. For negative stream use MIN + # *and* negate the result to get a positive number. sql = """ - SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s GROUP BY %(instance)s """ % { "instance": instance_column, "id": id_column, "table": table, + "agg": "MAX" if self._positive else "-MIN", } cur = db_conn.cursor() @@ -282,10 +268,11 @@ class MultiWriterIdGenerator: return current_positions - def _load_next_id_txn(self, txn): - txn.execute("SELECT nextval(?)", (self._sequence_name,)) - (next_id,) = txn.fetchone() - return next_id + def _load_next_id_txn(self, txn) -> int: + return self._sequence_gen.get_next_id_txn(txn) + + def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: + return self._sequence_gen.get_next_mult_txn(txn, n) async def get_next(self): """ @@ -298,20 +285,49 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than what we currently # believe the ID to be. If not, then the sequence and table have got # out of sync somehow. - assert self.get_current_token() < next_id - with self._lock: + assert self._current_positions.get(self._instance_name, 0) < next_id + self._unfinished_ids.add(next_id) @contextlib.contextmanager def manager(): try: - yield next_id + # Multiply by the return factor so that the ID has correct sign. + yield self._return_factor * next_id finally: self._mark_id_as_finished(next_id) return manager() + async def get_next_mult(self, n: int): + """ + Usage: + with await stream_id_gen.get_next_mult(5) as stream_ids: + # ... persist events ... + """ + next_ids = await self._db.runInteraction( + "_load_next_mult_id", self._load_next_mult_id_txn, n + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + with self._lock: + assert max(self._current_positions.values(), default=0) < min(next_ids) + + self._unfinished_ids.update(next_ids) + + @contextlib.contextmanager + def manager(): + try: + yield [self._return_factor * i for i in next_ids] + finally: + for i in next_ids: + self._mark_id_as_finished(i) + + return manager() + def get_next_txn(self, txn: LoggingTransaction): """ Usage: @@ -328,7 +344,7 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) - return next_id + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the @@ -344,33 +360,92 @@ class MultiWriterIdGenerator: curr = self._current_positions.get(self._instance_name, 0) self._current_positions[self._instance_name] = max(curr, next_id) - def get_current_token(self, instance_name: str = None) -> int: - """Gets the current position of a named writer (defaults to current - instance). + self._add_persisted_position(next_id) - Returns 0 if we don't have a position for the named writer (likely due - to it being a new writer). + def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. """ - if instance_name is None: - instance_name = self._instance_name + # Currently we don't support this operation, as it's not obvious how to + # condense the stream positions of multiple writers into a single int. + raise NotImplementedError() + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + """ with self._lock: - return self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get(instance_name, 0) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. """ with self._lock: - return dict(self._current_positions) + return { + name: self._return_factor * i + for name, i in self._current_positions.items() + } def advance(self, instance_name: str, new_id: int): """Advance the postion of the named writer to the given ID, if greater than existing entry. """ + new_id *= self._return_factor + with self._lock: self._current_positions[instance_name] = max( new_id, self._current_positions.get(instance_name, 0) ) + + self._add_persisted_position(new_id) + + def get_persisted_upto_position(self) -> int: + """Get the max position where all previous positions have been + persisted. + + Note: In the worst case scenario this will be equal to the minimum + position across writers. This means that the returned position here can + lag if one writer doesn't write very often. + """ + + with self._lock: + return self._return_factor * self._persisted_upto_position + + def _add_persisted_position(self, new_id: int): + """Record that we have persisted a position. + + This is used to keep the `_current_positions` up to date. + """ + + # We require that the lock is locked by caller + assert self._lock.locked() + + heapq.heappush(self._known_persisted_positions, new_id) + + # We move the current min position up if the minimum current positions + # of all instances is higher (since by definition all positions less + # that that have been persisted). + min_curr = min(self._current_positions.values()) + self._persisted_upto_position = max(min_curr, self._persisted_upto_position) + + # We now iterate through the seen positions, discarding those that are + # less than the current min positions, and incrementing the min position + # if its exactly one greater. + # + # This is also where we discard items from `_known_persisted_positions` + # (to ensure the list doesn't infinitely grow). + while self._known_persisted_positions: + if self._known_persisted_positions[0] <= self._persisted_upto_position: + heapq.heappop(self._known_persisted_positions) + elif ( + self._known_persisted_positions[0] == self._persisted_upto_position + 1 + ): + heapq.heappop(self._known_persisted_positions) + self._persisted_upto_position += 1 + else: + # There was a gap in seen positions, so there is nothing more to + # do. + break diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py new file mode 100644 index 0000000000..ffc1894748 --- /dev/null +++ b/synapse/storage/util/sequence.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import threading +from typing import Callable, List, Optional + +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.types import Cursor + + +class SequenceGenerator(metaclass=abc.ABCMeta): + """A class which generates a unique sequence of integers""" + + @abc.abstractmethod + def get_next_id_txn(self, txn: Cursor) -> int: + """Gets the next ID in the sequence""" + ... + + +class PostgresSequenceGenerator(SequenceGenerator): + """An implementation of SequenceGenerator which uses a postgres sequence""" + + def __init__(self, sequence_name: str): + self._sequence_name = sequence_name + + def get_next_id_txn(self, txn: Cursor) -> int: + txn.execute("SELECT nextval(?)", (self._sequence_name,)) + return txn.fetchone()[0] + + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + txn.execute( + "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n) + ) + return [i for (i,) in txn] + + +GetFirstCallbackType = Callable[[Cursor], int] + + +class LocalSequenceGenerator(SequenceGenerator): + """An implementation of SequenceGenerator which uses local locking + + This only works reliably if there are no other worker processes generating IDs at + the same time. + """ + + def __init__(self, get_first_callback: GetFirstCallbackType): + """ + Args: + get_first_callback: a callback which is called on the first call to + get_next_id_txn; should return the curreent maximum id + """ + # the callback. this is cleared after it is called, so that it can be GCed. + self._callback = get_first_callback # type: Optional[GetFirstCallbackType] + + # The current max value, or None if we haven't looked in the DB yet. + self._current_max_id = None # type: Optional[int] + self._lock = threading.Lock() + + def get_next_id_txn(self, txn: Cursor) -> int: + # We do application locking here since if we're using sqlite then + # we are a single process synapse. + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + self._current_max_id += 1 + return self._current_max_id + + +def build_sequence_generator( + database_engine: BaseDatabaseEngine, + get_first_callback: GetFirstCallbackType, + sequence_name: str, +) -> SequenceGenerator: + """Get the best impl of SequenceGenerator available + + This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on + sqlite. + + Args: + database_engine: the database engine we are connected to + get_first_callback: a callback which gets the next sequence ID. Used if + we're on sqlite. + sequence_name: the name of a postgres sequence to use. + """ + if isinstance(database_engine, PostgresEngine): + return PostgresSequenceGenerator(sequence_name) + else: + return LocalSequenceGenerator(get_first_callback) diff --git a/synapse/streams/config.py b/synapse/streams/config.py index cd56cd91ed..d97dc4d101 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) MAX_LIMIT = 1000 -class SourcePaginationConfig(object): +class SourcePaginationConfig: """A configuration object which stores pagination parameters for a specific event source.""" @@ -45,7 +45,7 @@ class SourcePaginationConfig(object): ) -class PaginationConfig(object): +class PaginationConfig: """A configuration object which stores pagination parameters.""" @@ -68,13 +68,13 @@ class PaginationConfig(object): elif from_tok: from_tok = StreamToken.from_string(from_tok) except Exception: - raise SynapseError(400, "'from' paramater is invalid") + raise SynapseError(400, "'from' parameter is invalid") try: if to_tok: to_tok = StreamToken.from_string(to_tok) except Exception: - raise SynapseError(400, "'to' paramater is invalid") + raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index fcd2aaa9c9..92fd5d489f 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -15,8 +15,6 @@ from typing import Any, Dict -from twisted.internet import defer - from synapse.handlers.account_data import AccountDataEventSource from synapse.handlers.presence import PresenceEventSource from synapse.handlers.receipts import ReceiptEventSource @@ -25,7 +23,7 @@ from synapse.handlers.typing import TypingNotificationEventSource from synapse.types import StreamToken -class EventSources(object): +class EventSources: SOURCE_TYPES = { "room": RoomEventSource, "presence": PresenceEventSource, @@ -40,19 +38,18 @@ class EventSources(object): } # type: Dict[str, Any] self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_current_token(self): - push_rules_key, _ = self.store.get_push_rules_stream_token() + def get_current_token(self) -> StreamToken: + push_rules_key = self.store.get_max_push_rules_stream_id() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() groups_key = self.store.get_group_stream_token() token = StreamToken( - room_key=(yield self.sources["room"].get_current_key()), - presence_key=(yield self.sources["presence"].get_current_key()), - typing_key=(yield self.sources["typing"].get_current_key()), - receipt_key=(yield self.sources["receipt"].get_current_key()), - account_data_key=(yield self.sources["account_data"].get_current_key()), + room_key=self.sources["room"].get_current_key(), + presence_key=self.sources["presence"].get_current_key(), + typing_key=self.sources["typing"].get_current_key(), + receipt_key=self.sources["receipt"].get_current_key(), + account_data_key=self.sources["account_data"].get_current_key(), push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, @@ -60,19 +57,18 @@ class EventSources(object): ) return token - @defer.inlineCallbacks - def get_current_token_for_pagination(self): + def get_current_token_for_pagination(self) -> StreamToken: """Get the current token for a given room to be used to paginate events. The returned token does not have the current values for fields other than `room`, since they are not used during pagination. - Retuns: - Deferred[StreamToken] + Returns: + The current token for pagination. """ token = StreamToken( - room_key=(yield self.sources["room"].get_current_key()), + room_key=self.sources["room"].get_current_key(), presence_key=0, typing_key=0, receipt_key=0, diff --git a/synapse/types.py b/synapse/types.py index acf60baddc..f7de48f148 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -13,11 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import re import string import sys from collections import namedtuple -from typing import Any, Dict, Tuple, TypeVar +from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -29,19 +30,20 @@ from synapse.api.errors import Codes, SynapseError if sys.version_info[:3] >= (3, 6, 0): from typing import Collection else: - from typing import Sized, Iterable, Container + from typing import Container, Iterable, Sized T_co = TypeVar("T_co", covariant=True) - class Collection(Iterable[T_co], Container[T_co], Sized): + class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore __slots__ = () # Define a state map type from type/state_key to T (usually an event ID or # event) T = TypeVar("T") -StateMap = Dict[Tuple[str, str], T] - +StateKey = Tuple[str, str] +StateMap = Mapping[StateKey, T] +MutableStateMap = MutableMapping[StateKey, T] # the type of a JSON-serialisable dict. This could be made stronger, but it will # do for now. @@ -50,7 +52,15 @@ JsonDict = Dict[str, Any] class Requester( namedtuple( - "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] + "Requester", + [ + "user", + "access_token_id", + "is_guest", + "shadow_banned", + "device_id", + "app_service", + ], ) ): """ @@ -61,6 +71,7 @@ class Requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request has been shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user """ @@ -76,6 +87,7 @@ class Requester( "user_id": self.user.to_string(), "access_token_id": self.access_token_id, "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, } @@ -100,13 +112,19 @@ class Requester( user=UserID.from_string(input["user_id"]), access_token_id=input["access_token_id"], is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, ) def create_requester( - user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None + user_id, + access_token_id=None, + is_guest=False, + shadow_banned=False, + device_id=None, + app_service=None, ): """ Create a new ``Requester`` object @@ -116,6 +134,7 @@ def create_requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user @@ -124,7 +143,9 @@ def create_requester( """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) - return Requester(user_id, access_token_id, is_guest, device_id, app_service) + return Requester( + user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + ) def get_domain_from_id(string): @@ -141,6 +162,9 @@ def get_localpart_from_id(string): return string[1:idx] +DS = TypeVar("DS", bound="DomainSpecificString") + + class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))): """Common base class among ID/name strings that have a local part and a domain name, prefixed with a sigil. @@ -151,6 +175,10 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 'domain' : The domain part of the name """ + __metaclass__ = abc.ABCMeta + + SIGIL = abc.abstractproperty() # type: str # type: ignore + # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) @@ -166,7 +194,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom return self @classmethod - def from_string(cls, s: str): + def from_string(cls: Type[DS], s: str) -> DS: """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( @@ -190,12 +218,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom # names on one HS return cls(localpart=parts[0], domain=domain) - def to_string(self): + def to_string(self) -> str: """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) @classmethod - def is_valid(cls, s): + def is_valid(cls: Type[DS], s: str) -> bool: try: cls.from_string(s) return True @@ -235,8 +263,9 @@ class GroupID(DomainSpecificString): SIGIL = "+" @classmethod - def from_string(cls, s): - group_id = super(GroupID, cls).from_string(s) + def from_string(cls: Type[DS], s: str) -> DS: + group_id = super().from_string(s) # type: DS # type: ignore + if not group_id.localpart: raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) @@ -500,7 +529,7 @@ class ThirdPartyInstanceID( @attr.s(slots=True) -class ReadReceipt(object): +class ReadReceipt: """Information about a read-receipt""" room_id = attr.ib() diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 60f0de70f7..a13f11f8d8 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -17,6 +17,7 @@ import logging import re import attr +from canonicaljson import json from twisted.internet import defer, task @@ -25,6 +26,19 @@ from synapse.logging import context logger = logging.getLogger(__name__) +def _reject_invalid_json(val): + """Do not allow Infinity, -Infinity, or NaN values in JSON.""" + raise ValueError("Invalid JSON value: '%s'" % val) + + +# Create a custom encoder to reduce the whitespace produced by JSON encoding and +# ensure that valid JSON is produced. +json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":")) + +# Create a custom decoder to reject Python extensions to JSON. +json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) + + def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) @@ -32,7 +46,7 @@ def unwrapFirstError(failure): @attr.s -class Clock(object): +class Clock: """ A Clock wraps a Twisted reactor and provides utilities on top of it. @@ -55,7 +69,7 @@ class Clock(object): return self._reactor.seconds() def time_msec(self): - """Returns the current system time in miliseconds since epoch.""" + """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) def looping_call(self, f, msec, *args, **kwargs): diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index f7af2bca7f..bb57e27beb 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -19,9 +19,8 @@ import logging from contextlib import contextmanager from typing import Dict, Sequence, Set, Union -from six.moves import range - import attr +from typing_extensions import ContextManager from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -37,7 +36,7 @@ from synapse.util import Clock, unwrapFirstError logger = logging.getLogger(__name__) -class ObservableDeferred(object): +class ObservableDeferred: """Wraps a deferred object so that we can add observer deferreds. These observer deferreds do not affect the callback chain of the original deferred. @@ -95,7 +94,7 @@ class ObservableDeferred(object): This returns a brand new deferred that is resolved when the underlying deferred is resolved. Interacting with the returned deferred does not - effect the underdlying deferred. + effect the underlying deferred. """ if not self._result: d = defer.Deferred() @@ -189,7 +188,7 @@ def yieldable_gather_results(func, iter, *args, **kwargs): ).addErrback(unwrapFirstError) -class Linearizer(object): +class Linearizer: """Limits concurrent access to resources based on a key. Useful to ensure only a few things happen at a time on a given resource. @@ -339,12 +338,12 @@ class Linearizer(object): return new_defer -class ReadWriteLock(object): - """A deferred style read write lock. +class ReadWriteLock: + """An async read write lock. Example: - with (yield read_write_lock.read("test_key")): + with await read_write_lock.read("test_key"): # do some work """ @@ -354,7 +353,7 @@ class ReadWriteLock(object): # resolved when they release the lock). # # Read: We know its safe to acquire a read lock when the latest writer has - # been resolved. The new reader is appeneded to the list of latest readers. + # been resolved. The new reader is appended to the list of latest readers. # # Write: We know its safe to acquire the write lock when both the latest # writers and readers have been resolved. The new writer replaces the latest @@ -367,8 +366,7 @@ class ReadWriteLock(object): # Latest writer queued self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] - @defer.inlineCallbacks - def read(self, key): + async def read(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.setdefault(key, set()) @@ -378,7 +376,8 @@ class ReadWriteLock(object): # We wait for the latest writer to finish writing. We can safely ignore # any existing readers... as they're readers. - yield make_deferred_yieldable(curr_writer) + if curr_writer: + await make_deferred_yieldable(curr_writer) @contextmanager def _ctx_manager(): @@ -390,8 +389,7 @@ class ReadWriteLock(object): return _ctx_manager() - @defer.inlineCallbacks - def write(self, key): + async def write(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.get(key, set()) @@ -407,7 +405,7 @@ class ReadWriteLock(object): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) + await make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): @@ -504,7 +502,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): @attr.s(slots=True, frozen=True) -class DoneAwaitable(object): +class DoneAwaitable: """Simple awaitable that returns the provided value. """ diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index dd356bf156..237f588658 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -43,7 +43,7 @@ response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["n @attr.s -class CacheMetric(object): +class CacheMetric: _cache = attr.ib() _cache_type = attr.ib(type=str) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index cd48262420..98b34f2223 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,13 +18,10 @@ import functools import inspect import logging import threading -from typing import Any, Tuple, Union, cast +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary -from six import itervalues - from prometheus_client import Gauge -from typing_extensions import Protocol from twisted.internet import defer @@ -40,8 +37,10 @@ logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] +F = TypeVar("F", bound=Callable[..., Any]) + -class _CachedFunction(Protocol): +class _CachedFunction(Generic[F]): invalidate = None # type: Any invalidate_all = None # type: Any invalidate_many = None # type: Any @@ -49,8 +48,11 @@ class _CachedFunction(Protocol): cache = None # type: Any num_args = None # type: Any - def __name__(self): - ... + __name__ = None # type: str + + # Note: This function signature is actually fiddled with by the synapse mypy + # plugin to a) make it a bound method, and b) remove any `cache_context` arg. + __call__ = None # type: F cache_pending_metric = Gauge( @@ -62,7 +64,7 @@ cache_pending_metric = Gauge( _CacheSentinel = object() -class CacheEntry(object): +class CacheEntry: __slots__ = ["deferred", "callbacks", "invalidated"] def __init__(self, deferred, callbacks): @@ -78,7 +80,7 @@ class CacheEntry(object): self.callbacks.clear() -class Cache(object): +class Cache: __slots__ = ( "cache", "name", @@ -125,7 +127,7 @@ class Cache(object): self.name = name self.keylen = keylen - self.thread = None + self.thread = None # type: Optional[threading.Thread] self.metrics = register_cache( "cache", name, @@ -194,7 +196,7 @@ class Cache(object): callbacks = [callback] if callback else [] self.check_thread() observable = ObservableDeferred(value, consumeErrors=True) - observer = defer.maybeDeferred(observable.observe) + observer = observable.observe() entry = CacheEntry(deferred=observable, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) @@ -281,22 +283,15 @@ class Cache(object): def invalidate_all(self): self.check_thread() self.cache.clear() - for entry in itervalues(self._pending_deferred_cache): + for entry in self._pending_deferred_cache.values(): entry.invalidate() self._pending_deferred_cache.clear() -class _CacheDescriptorBase(object): - def __init__( - self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False - ): +class _CacheDescriptorBase: + def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args @@ -366,7 +361,7 @@ class CacheDescriptor(_CacheDescriptorBase): invalidated) by adding a special "cache_context" argument to the function and passing that as a kwarg to all caches called. For example:: - @cachedInlineCallbacks(cache_context=True) + @cached(cache_context=True) def foo(self, key, cache_context): r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) @@ -384,17 +379,11 @@ class CacheDescriptor(_CacheDescriptorBase): max_entries=1000, num_args=None, tree=False, - inlineCallbacks=False, cache_context=False, iterable=False, ): - super(CacheDescriptor, self).__init__( - orig, - num_args=num_args, - inlineCallbacks=inlineCallbacks, - cache_context=cache_context, - ) + super().__init__(orig, num_args=num_args, cache_context=cache_context) self.max_entries = max_entries self.tree = tree @@ -467,9 +456,7 @@ class CacheDescriptor(_CacheDescriptorBase): observer = defer.succeed(cached_result_d) except KeyError: - ret = defer.maybeDeferred( - preserve_fn(self.function_to_call), obj, *args, **kwargs - ) + ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) def onErr(f): cache.invalidate(cache_key) @@ -512,23 +499,17 @@ class CacheListDescriptor(_CacheDescriptorBase): of results. """ - def __init__( - self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False - ): + def __init__(self, orig, cached_method_name, list_name, num_args=None): """ Args: orig (function) - cached_method_name (str): The name of the chached method. + cached_method_name (str): The name of the cached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int): number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all named args of the function. - inlineCallbacks (bool): Whether orig is a generator that should - be wrapped by defer.inlineCallbacks """ - super(CacheListDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks - ) + super().__init__(orig, num_args=num_args) self.list_name = list_name @@ -633,7 +614,7 @@ class CacheListDescriptor(_CacheDescriptorBase): cached_defers.append( defer.maybeDeferred( - preserve_fn(self.function_to_call), **args_to_call + preserve_fn(self.orig), **args_to_call ).addCallbacks(complete_all, errback) ) @@ -685,9 +666,13 @@ class _CacheContext: def cached( - max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False -): - return lambda orig: CacheDescriptor( + max_entries: int = 1000, + num_args: Optional[int] = None, + tree: bool = False, + cache_context: bool = False, + iterable: bool = False, +) -> Callable[[F], _CachedFunction[F]]: + func = lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -696,22 +681,12 @@ def cached( iterable=iterable, ) - -def cachedInlineCallbacks( - max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False -): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - tree=tree, - inlineCallbacks=True, - cache_context=cache_context, - iterable=iterable, - ) + return cast(Callable[[F], _CachedFunction[F]], func) -def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): +def cachedList( + cached_method_name: str, list_name: str, num_args: Optional[int] = None +) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -721,18 +696,16 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal cache. Args: - cached_method_name (str): The name of the single-item lookup method. + cached_method_name: The name of the single-item lookup method. This is only used to find the cache to use. - list_name (str): The name of the argument that is the list to use to + list_name: The name of the argument that is the list to use to do batch lookups in the cache. - num_args (int): Number of arguments to use as the key in the cache + num_args: Number of arguments to use as the key in the cache (including list_name). Defaults to all named parameters. - inlineCallbacks (bool): Should the function be wrapped in an - `defer.inlineCallbacks`? Example: - class Example(object): + class Example: @cached(num_args=2) def do_something(self, first_arg): ... @@ -741,10 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal def batch_do_something(self, first_arg, second_args): ... """ - return lambda orig: CacheListDescriptor( + func = lambda orig: CacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, - inlineCallbacks=inlineCallbacks, ) + + return cast(Callable[[F], _CachedFunction[F]], func) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 6834e6f3ae..8592b93689 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -40,7 +40,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va return len(self.value) -class DictionaryCache(object): +class DictionaryCache: """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ @@ -53,7 +53,7 @@ class DictionaryCache(object): self.thread = None # caches_by_name[name] = self.cache - class Sentinel(object): + class Sentinel: __slots__ = [] self.sentinel = Sentinel() diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 2726b67b6d..e15f7ee698 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -16,8 +16,6 @@ import logging from collections import OrderedDict -from six import iteritems, itervalues - from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache @@ -28,7 +26,7 @@ logger = logging.getLogger(__name__) SENTINEL = object() -class ExpiringCache(object): +class ExpiringCache: def __init__( self, cache_name, @@ -150,7 +148,7 @@ class ExpiringCache(object): keys_to_delete = set() - for key, cache_entry in iteritems(self._cache): + for key, cache_entry in self._cache.items(): if now - cache_entry.time > self._expiry_ms: keys_to_delete.add(key) @@ -170,7 +168,7 @@ class ExpiringCache(object): def __len__(self): if self.iterable: - return sum(len(entry.value) for entry in itervalues(self._cache)) + return sum(len(entry.value) for entry in self._cache.values()) else: return len(self._cache) @@ -192,7 +190,7 @@ class ExpiringCache(object): return False -class _CacheEntry(object): +class _CacheEntry: __slots__ = ["time", "value"] def __init__(self, time, value): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index df4ea5901d..4bc1a67b58 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -30,7 +30,7 @@ def enumerate_leaves(node, depth): yield m -class _Node(object): +class _Node: __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] def __init__(self, prev_node, next_node, key, value, callbacks=set()): @@ -41,7 +41,7 @@ class _Node(object): self.callbacks = callbacks -class LruCache(object): +class LruCache: """ Least-recently-used cache. Supports del_multi only if cache_type=TreeCache diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index a6c60888e5..df1a721add 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -23,7 +23,7 @@ from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -class ResponseCache(object): +class ResponseCache: """ This caches a deferred response. Until the deferred completes it will be returned from the cache. This means that if the client retries the request diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 2a161bf244..c541bf4579 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -17,8 +17,6 @@ import logging import math from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union -from six import integer_types - from sortedcontainers import SortedDict from synapse.types import Collection @@ -88,7 +86,7 @@ class StreamChangeCache: def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ - assert type(stream_pos) in integer_types + assert isinstance(stream_pos, int) if stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 2ea4e4e911..eb4d98f683 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -1,11 +1,9 @@ from typing import Dict -from six import itervalues - SENTINEL = object() -class TreeCache(object): +class TreeCache: """ Tree-based backing store for LruCache. Allows subtrees of data to be deleted efficiently. @@ -81,7 +79,7 @@ def iterate_tree_cache_entry(d): can contain dicts. """ if isinstance(d, dict): - for value_d in itervalues(d): + for value_d in d.values(): for value in iterate_tree_cache_entry(value_d): yield value else: @@ -91,7 +89,7 @@ def iterate_tree_cache_entry(d): yield d -class _Entry(object): +class _Entry: __slots__ = ["value"] def __init__(self, value): diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 6437aa907e..3e180cafd3 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) SENTINEL = object() -class TTLCache(object): +class TTLCache: """A key/value cache implementation where each entry has its own TTL""" def __init__(self, cache_name, timer=time.time): @@ -154,7 +154,7 @@ class TTLCache(object): @attr.s(frozen=True, slots=True) -class _CacheEntry(object): +class _CacheEntry: """TTLCache entry""" # expiry_time is the first attribute, so that entries are sorted by expiry. diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py new file mode 100644 index 0000000000..23393cf49b --- /dev/null +++ b/synapse/util/daemonize.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2012, 2013, 2014 Ilya Otyutskiy <ilya.otyutskiy@icloud.com> +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import fcntl +import logging +import os +import signal +import sys + + +def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None: + """daemonize the current process + + This calls fork(), and has the main process exit. When it returns we will be + running in the child process. + """ + + # If pidfile already exists, we should read pid from there; to overwrite it, if + # locking will fail, because locking attempt somehow purges the file contents. + if os.path.isfile(pid_file): + with open(pid_file, "r") as pid_fh: + old_pid = pid_fh.read() + + # Create a lockfile so that only one instance of this daemon is running at any time. + try: + lock_fh = open(pid_file, "w") + except IOError: + print("Unable to create the pidfile.") + sys.exit(1) + + try: + # Try to get an exclusive lock on the file. This will fail if another process + # has the file locked. + fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB) + except IOError: + print("Unable to lock on the pidfile.") + # We need to overwrite the pidfile if we got here. + # + # XXX better to avoid overwriting it, surely. this looks racey as the pid file + # could be created between us trying to read it and us trying to lock it. + with open(pid_file, "w") as pid_fh: + pid_fh.write(old_pid) + sys.exit(1) + + # Fork, creating a new process for the child. + process_id = os.fork() + + if process_id != 0: + # parent process: exit. + + # we use os._exit to avoid running the atexit handlers. In particular, that + # means we don't flush the logs. This is important because if we are using + # a MemoryHandler, we could have logs buffered which are now buffered in both + # the main and the child process, so if we let the main process flush the logs, + # we'll get two copies. + os._exit(0) + + # This is the child process. Continue. + + # Stop listening for signals that the parent process receives. + # This is done by getting a new process id. + # setpgrp() is an alternative to setsid(). + # setsid puts the process in a new parent group and detaches its controlling + # terminal. + + os.setsid() + + # point stdin, stdout, stderr at /dev/null + devnull = "/dev/null" + if hasattr(os, "devnull"): + # Python has set os.devnull on this system, use it instead as it might be + # different than /dev/null. + devnull = os.devnull + + devnull_fd = os.open(devnull, os.O_RDWR) + os.dup2(devnull_fd, 0) + os.dup2(devnull_fd, 1) + os.dup2(devnull_fd, 2) + os.close(devnull_fd) + + # now that we have redirected stderr to /dev/null, any uncaught exceptions will + # get sent to /dev/null, so make sure we log them. + # + # (we don't normally expect reactor.run to raise any exceptions, but this will + # also catch any other uncaught exceptions before we get that far.) + + def excepthook(type_, value, traceback): + logger.critical("Unhanded exception", exc_info=(type_, value, traceback)) + + sys.excepthook = excepthook + + # Set umask to default to safe file permissions when running as a root daemon. 027 + # is an octal number which we are typing as 0o27 for Python3 compatibility. + os.umask(0o27) + + # Change to a known directory. If this isn't done, starting a daemon in a + # subdirectory that needs to be deleted results in "directory busy" errors. + os.chdir(chdir) + + try: + lock_fh.write("%s" % (os.getpid())) + lock_fh.flush() + except IOError: + logger.error("Unable to write pid to the pidfile.") + print("Unable to write pid to the pidfile.") + sys.exit(1) + + # write a log line on SIGTERM. + def sigterm(signum, frame): + logger.warning("Caught signal %s. Stopping daemon." % signum) + sys.exit(0) + + signal.signal(signal.SIGTERM, sigterm) + + # Cleanup pid file at exit. + def exit(): + logger.warning("Stopping daemon.") + os.remove(pid_file) + sys.exit(0) + + atexit.register(exit) + + logger.warning("Starting daemon.") diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 45af8d3eeb..a750261e77 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect import logging from twisted.internet import defer +from twisted.internet.defer import Deferred, fail, succeed +from twisted.python import failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -32,14 +34,14 @@ def user_joined_room(distributor, user, room_id): distributor.fire("user_joined_room", user=user, room_id=room_id) -class Distributor(object): +class Distributor: """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. Signals are named simply by strings. TODO(paul): It would be nice to give signals stronger object identities, - so we can attach metadata, docstrings, detect typoes, etc... But this + so we can attach metadata, docstrings, detect typos, etc... But this model will do for today. """ @@ -79,7 +81,29 @@ class Distributor(object): run_as_background_process(name, self.signals[name].fire, *args, **kwargs) -class Signal(object): +def maybeAwaitableDeferred(f, *args, **kw): + """ + Invoke a function that may or may not return a Deferred or an Awaitable. + + This is a modified version of twisted.internet.defer.maybeDeferred. + """ + try: + result = f(*args, **kw) + except Exception: + return fail(failure.Failure(captureVars=Deferred.debug)) + + if isinstance(result, Deferred): + return result + # Handle the additional case of an awaitable being returned. + elif inspect.isawaitable(result): + return defer.ensureDeferred(result) + elif isinstance(result, failure.Failure): + return fail(result) + else: + return succeed(result) + + +class Signal: """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -122,7 +146,7 @@ class Signal(object): ), ) - return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) + return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb) deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 8b17d1c8b8..733f5e26e6 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six.moves import queue +import queue from twisted.internet import threads from synapse.logging.context import make_deferred_yieldable, run_in_background -class BackgroundFileConsumer(object): +class BackgroundFileConsumer: """A consumer that writes to a file like object. Supports both push and pull producers diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 9815bb8667..0e445e01d7 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import binary_type, text_type - from canonicaljson import json from frozendict import frozendict @@ -26,7 +24,7 @@ def freeze(o): if isinstance(o, frozendict): return o - if isinstance(o, (binary_type, text_type)): + if isinstance(o, (bytes, str)): return o try: @@ -41,7 +39,7 @@ def unfreeze(o): if isinstance(o, (dict, frozendict)): return dict({k: unfreeze(v) for k, v in o.items()}) - if isinstance(o, (binary_type, text_type)): + if isinstance(o, (bytes, str)): return o try: @@ -65,5 +63,8 @@ def _handle_frozendict(obj): ) -# A JSONEncoder which is capable of encoding frozendicts without barfing -frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) +# A JSONEncoder which is capable of encoding frozendicts without barfing. +# Additionally reduce the whitespace produced by JSON encoding. +frozendict_json_encoder = json.JSONEncoder( + default=_handle_frozendict, separators=(",", ":"), +) diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index 6dce03dd3a..50516926f3 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -14,7 +14,7 @@ # limitations under the License. -class JsonEncodedObject(object): +class JsonEncodedObject: """ A common base class for defining protocol units that are represented as JSON. diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index ec61e14423..6e57c1ee72 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,14 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging from functools import wraps +from typing import Any, Callable, Optional, TypeVar, cast from prometheus_client import Counter -from twisted.internet import defer - from synapse.logging.context import LoggingContext, current_context from synapse.metrics import InFlightGauge @@ -60,34 +58,42 @@ in_flight = InFlightGauge( sub_metrics=["real_time_max", "real_time_sum"], ) +T = TypeVar("T", bound=Callable[..., Any]) -def measure_func(name=None): - def wrapper(func): - block_name = func.__name__ if name is None else name - if inspect.iscoroutinefunction(func): +def measure_func(name: Optional[str] = None) -> Callable[[T], T]: + """ + Used to decorate an async function with a `Measure` context manager. + + Usage: + + @measure_func() + async def foo(...): + ... - @wraps(func) - async def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = await func(self, *args, **kwargs) - return r + Which is analogous to: - else: + async def foo(...): + with Measure(...): + ... + + """ + + def wrapper(func: T) -> T: + block_name = func.__name__ if name is None else name - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = yield func(self, *args, **kwargs) - return r + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r - return measured_func + return cast(T, measured_func) return wrapper -class Measure(object): +class Measure: __slots__ = [ "clock", "name", diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 2605f3c65b..54c046b6e1 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -192,7 +192,7 @@ def _check_yield_points(f: Callable, changes: List[str]): result = yield d except Exception: # this will fish an earlier Failure out of the stack where possible, and - # thus is preferable to passing in an exeception to the Failure + # thus is preferable to passing in an exception to the Failure # constructor, since it results in less stack-mangling. result = Failure() diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index e5efdfcd02..70d11e1ec3 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -29,7 +29,7 @@ from synapse.logging.context import ( logger = logging.getLogger(__name__) -class FederationRateLimiter(object): +class FederationRateLimiter: def __init__(self, clock, config): """ Args: @@ -60,7 +60,7 @@ class FederationRateLimiter(object): return self.ratelimiters[host].ratelimit() -class _PerHostRatelimiter(object): +class _PerHostRatelimiter: def __init__(self, clock, config): """ Args: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index af69587196..79869aaa44 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -15,14 +15,12 @@ import logging import random -from twisted.internet import defer - import synapse.logging.context from synapse.api.errors import CodeMessageException logger = logging.getLogger(__name__) -# the intial backoff, after the first transaction fails +# the initial backoff, after the first transaction fails MIN_RETRY_INTERVAL = 10 * 60 * 1000 # how much we multiply the backoff by after each subsequent fail @@ -54,8 +52,7 @@ class NotRetryingDestination(Exception): self.destination = destination -@defer.inlineCallbacks -def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): +async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): """For a given destination check if we have previously failed to send a request there and are waiting before retrying the destination. If we are not ready to retry the destination, this will raise a @@ -73,9 +70,9 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) Example usage: try: - limiter = yield get_retry_limiter(destination, clock, store) + limiter = await get_retry_limiter(destination, clock, store) with limiter: - response = yield do_request() + response = await do_request() except NotRetryingDestination: # We aren't ready to retry that destination. raise @@ -83,7 +80,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) failure_ts = None retry_last_ts, retry_interval = (0, 0) - retry_timings = yield store.get_destination_retry_timings(destination) + retry_timings = await store.get_destination_retry_timings(destination) if retry_timings: failure_ts = retry_timings["failure_ts"] @@ -117,7 +114,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs) ) -class RetryDestinationLimiter(object): +class RetryDestinationLimiter: def __init__( self, destination, @@ -174,7 +171,7 @@ class RetryDestinationLimiter(object): # has been decommissioned. # If we get a 401, then we should probably back off since they # won't accept our requests for at least a while. - # 429 is us being aggresively rate limited, so lets rate limit + # 429 is us being aggressively rate limited, so lets rate limit # ourselves. if exc_val.code == 404 and self.backoff_on_404: valid_err_code = False @@ -222,10 +219,9 @@ class RetryDestinationLimiter(object): if self.failure_ts is None: self.failure_ts = retry_last_ts - @defer.inlineCallbacks - def store_retry_timings(): + async def store_retry_timings(): try: - yield self.store.set_destination_retry_timings( + await self.store.set_destination_retry_timings( self.destination, self.failure_ts, retry_last_ts, diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 08c86e92b8..61d96a6c28 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -17,16 +17,14 @@ import itertools import random import re import string -from collections import Iterable +from collections.abc import Iterable from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken -# Note: The : character is allowed here for older clients, but will be removed in a -# future release. Context: https://github.com/matrix-org/synapse/issues/6766 -client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$") +client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 20cf4c4a81..bd63b9107e 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -74,3 +74,26 @@ async def check_3pid_allowed(hs, medium, address): return True return False + + +def canonicalise_email(address: str) -> str: + """'Canonicalise' email address + Case folding of local part of email address and lowercase domain part + See MSC2265, https://github.com/matrix-org/matrix-doc/pull/2265 + + Args: + address: email address to be canonicalised + Returns: + The canonical form of the email address + Raises: + ValueError if the address could not be parsed. + """ + + address = address.strip() + + parts = address.split("@") + if len(parts) != 2: + logger.debug("Couldn't parse email address %s", address) + raise ValueError("Unable to parse email address") + + return parts[0].casefold() + "@" + parts[1].lower() diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 9bf6a44f75..be3b22469d 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -13,10 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six.moves import range - -class _Entry(object): +class _Entry: __slots__ = ["end_key", "queue"] def __init__(self, end_key): @@ -24,7 +22,7 @@ class _Entry(object): self.queue = [] -class WheelTimer(object): +class WheelTimer: """Stores arbitrary objects that will be returned after their timers have expired. """ diff --git a/synapse/visibility.py b/synapse/visibility.py index bab41182b9..e3da7744d2 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,11 +16,6 @@ import logging import operator -from six import iteritems, itervalues -from six.moves import map - -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event from synapse.storage import Storage @@ -42,8 +37,7 @@ MEMBERSHIP_PRIORITY = ( ) -@defer.inlineCallbacks -def filter_events_for_client( +async def filter_events_for_client( storage: Storage, user_id, events, @@ -70,19 +64,19 @@ def filter_events_for_client( also be called to check whether a user can see the state at a given point. Returns: - Deferred[list[synapse.events.EventBase]] + list[synapse.events.EventBase] """ # Filter out events that have been soft failed so that we don't relay them # to clients. events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield storage.state.get_state_for_events( + event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( + ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -93,7 +87,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) if filter_send_to_client: room_ids = {e.room_id for e in events} @@ -102,7 +96,7 @@ def filter_events_for_client( for room_id in room_ids: retention_policies[ room_id - ] = yield storage.main.get_retention_policy_for_room(room_id) + ] = await storage.main.get_retention_policy_for_room(room_id) def allowed(event): """ @@ -257,8 +251,7 @@ def filter_events_for_client( return list(filtered_events) -@defer.inlineCallbacks -def filter_events_for_server( +async def filter_events_for_server( storage: Storage, server_name, events, @@ -280,7 +273,7 @@ def filter_events_for_server( backfill or not. Returns - Deferred[list[FrozenEvent]] + list[FrozenEvent] """ def is_sender_erased(event, erased_senders): @@ -298,7 +291,7 @@ def filter_events_for_server( # membership states for the requesting server to determine # if the server is either in the room or has been invited # into the room. - for ev in itervalues(state): + for ev in state.values(): if ev.type != EventTypes.Member: continue try: @@ -322,9 +315,9 @@ def filter_events_for_server( return True # Lets check to see if all the events have a history visibility - # of "shared" or "world_readable". If thats the case then we don't + # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -332,24 +325,24 @@ def filter_events_for_server( ) visibility_ids = set() - for sids in itervalues(event_to_state_ids): + for sids in event_to_state_ids.values(): hist = sids.get((EventTypes.RoomHistoryVisibility, "")) if hist: visibility_ids.add(hist) # If we failed to find any history visibility events then the default - # is "shared" visiblity. + # is "shared" visibility. if not visibility_ids: all_open = True else: - event_map = yield storage.main.get_events(visibility_ids) + event_map = await storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") - for e in itervalues(event_map) + for e in event_map.values() ) if not check_history_visibility_only: - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -378,7 +371,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -394,8 +387,8 @@ def filter_events_for_server( # event_id_to_state_key = { event_id: key - for key_to_eid in itervalues(event_to_state_ids) - for key, event_id in iteritems(key_to_eid) + for key_to_eid in event_to_state_ids.values() + for key, event_id in key_to_eid.items() } def include(typ, state_key): @@ -408,21 +401,17 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield storage.main.get_events( - [ - e_id - for e_id, key in iteritems(event_id_to_state_key) - if include(key[0], key[1]) - ] + event_map = await storage.main.get_events( + [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])] ) event_to_state = { e_id: { key: event_map[inner_e_id] - for key, inner_e_id in iteritems(key_to_eid) + for key, inner_e_id in key_to_eid.items() if inner_e_id in event_map } - for e_id, key_to_eid in iteritems(event_to_state_ids) + for e_id, key_to_eid in event_to_state_ids.items() } to_return = [] diff --git a/synctl b/synctl index 960fd357ee..9395ebd048 100755 --- a/synctl +++ b/synctl @@ -26,8 +26,6 @@ import subprocess import sys import time -from six import iteritems - import yaml from synapse.config import find_config_files @@ -241,7 +239,8 @@ def main(): for config_file in config_files: with open(config_file) as file_stream: yaml_config = yaml.safe_load(file_stream) - config.update(yaml_config) + if yaml_config is not None: + config.update(yaml_config) pidfile = config["pid_file"] cache_factor = config.get("synctl_cache_factor") @@ -251,7 +250,7 @@ def main(): os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) cache_factors = config.get("synctl_cache_factors", {}) - for cache_name, factor in iteritems(cache_factors): + for cache_name, factor in cache_factors.items(): os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) worker_configfiles = [] @@ -362,7 +361,7 @@ def main(): if worker.cache_factor: os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) - for cache_name, factor in iteritems(worker.cache_factors): + for cache_name, factor in worker.cache_factors.items(): os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) if not start_worker(worker.app, configfile, worker.configfile): diff --git a/synmark/__init__.py b/synmark/__init__.py index afe4fad8cb..53698bd5ab 100644 --- a/synmark/__init__.py +++ b/synmark/__init__.py @@ -47,9 +47,9 @@ async def make_homeserver(reactor, config=None): stor = hs.get_datastore() # Run the database background updates. - if hasattr(stor.db.updates, "do_next_background_update"): - while not await stor.db.updates.has_completed_background_updates(): - await stor.db.updates.do_next_background_update(1) + if hasattr(stor.db_pool.updates, "do_next_background_update"): + while not await stor.db_pool.updates.has_completed_background_updates(): + await stor.db_pool.updates.do_next_background_update(1) def cleanup(): for i in cleanup_tasks: diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 0bfb86bf1f..8ab56ec94c 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -36,7 +36,7 @@ from tests import unittest from tests.utils import mock_getRawHeaders, setup_test_homeserver -class TestHandlers(object): +class TestHandlers: def __init__(self, hs): self.auth_handler = synapse.handlers.auth.AuthHandler(hs) @@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) + self.store.insert_client_ip = Mock(return_value=defer.succeed(None)) self.store.is_support_user = Mock(return_value=defer.succeed(False)) @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase): self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): user_info = {"name": self.test_user, "token_id": "ditto"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase): token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" @@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") @@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + # This just needs to return a truth-y value. + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) self.failureResultOf(d, AuthError) @defer.inlineCallbacks def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( - return_value={"name": "@baldrick:matrix.org", "device_id": "device"} + return_value=defer.succeed( + {"name": "@baldrick:matrix.org", "device_id": "device"} + ) ) user_id = "@baldrick:matrix.org" @@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): - self.store.get_user_by_id = Mock(return_value={"is_guest": True}) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True})) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase): def get_user(tok): if token != tok: - return None - return { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + return defer.succeed(None) + return defer.succeed( + { + "name": USER_ID, + "is_guest": False, + "token_id": 1234, + "device_id": "DEVICE", + } + ) self.store.get_user_by_access_token = get_user - self.store.get_user_by_id = Mock(return_value={"is_guest": False}) + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) # check the token works request = Mock(args={}) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 4e67503cf0..d2d535d23c 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -369,14 +369,18 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -386,8 +390,10 @@ class FilteringTestCase(unittest.TestCase): def test_filter_presence_no_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart + "2", user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart + "2", user_filter=user_filter_json + ) ) event = MockEvent( event_id="$asdasd:localhost", @@ -396,8 +402,10 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart + "2", filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart + "2", filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -406,14 +414,18 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_filter_room_state_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events=events) @@ -422,16 +434,20 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_filter_room_state_no_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) event = MockEvent( sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events) @@ -457,16 +473,20 @@ class FilteringTestCase(unittest.TestCase): def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.filtering.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.filtering.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) self.assertEquals(filter_id, 0) self.assertEquals( user_filter_json, ( - yield self.datastore.get_user_filter( - user_localpart=user_localpart, filter_id=0 + yield defer.ensureDeferred( + self.datastore.get_user_filter( + user_localpart=user_localpart, filter_id=0 + ) ) ), ) @@ -475,12 +495,16 @@ class FilteringTestCase(unittest.TestCase): def test_get_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) - filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) self.assertEquals(filter.get_filter_json(), user_filter_json) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index d580e729c5..1e1f30d790 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,4 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter +from synapse.appservice import ApplicationService +from synapse.types import create_requester from tests import unittest @@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase): self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) + def test_allowed_user_via_can_requester_do_action(self): + user_requester = create_requester("@user:example.com") + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=True, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=False, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + def test_allowed_via_ratelimit(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index be20a89682..641093d349 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -30,6 +30,16 @@ class FrontendProxyTests(HomeserverTestCase): def default_config(self): c = super().default_config() c["worker_app"] = "synapse.app.frontend_proxy" + + c["worker_listeners"] = [ + { + "type": "http", + "port": 8080, + "bind_addresses": ["0.0.0.0"], + "resources": [{"names": ["client"]}], + } + ] + return c def test_listen_http_with_presence_enabled(self): @@ -39,14 +49,8 @@ class FrontendProxyTests(HomeserverTestCase): # Presence is on self.hs.config.use_presence = True - config = { - "port": 8080, - "bind_addresses": ["0.0.0.0"], - "resources": [{"names": ["client"]}], - } - # Listen with the config - self.hs._listen_http(config) + self.hs._listen_http(self.hs.config.worker.worker_listeners[0]) # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) @@ -67,14 +71,8 @@ class FrontendProxyTests(HomeserverTestCase): # Presence is off self.hs.config.use_presence = False - config = { - "port": 8080, - "bind_addresses": ["0.0.0.0"], - "resources": [{"names": ["client"]}], - } - # Listen with the config - self.hs._listen_http(config) + self.hs._listen_http(self.hs.config.worker.worker_listeners[0]) # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 7364f9f1ec..0f016c32eb 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -18,6 +18,7 @@ from parameterized import parameterized from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer +from synapse.config.server import parse_listener_def from tests.unittest import HomeserverTestCase @@ -35,6 +36,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): # have to tell the FederationHandler not to try to access stuff that is only # in the primary store. conf["worker_app"] = "yes" + return conf @parameterized.expand( @@ -53,12 +55,13 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): """ config = { "port": 8080, + "type": "http", "bind_addresses": ["0.0.0.0"], "resources": [{"names": names}], } # Listen with the config - self.hs._listen_http(config) + self.hs._listen_http(parse_listener_def(config)) # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] @@ -101,12 +104,13 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): """ config = { "port": 8080, + "type": "http", "bind_addresses": ["0.0.0.0"], "resources": [{"names": names}], } # Listen with the config - self.hs._listener_http(config, config) + self.hs._listener_http(self.hs.get_config(), parse_listener_def(config)) # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 4003869ed6..236b608d58 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_regex_user_id_prefix_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_user_id_prefix_no_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" - self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_member_is_checked(self): @@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_match(self): @@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_no_match(self): @@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" - self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_alias_match(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#irc_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) def test_non_exclusive_alias(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( @@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#xmpp_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertFalse((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertFalse( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_regex_multiple_matches(self): @@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_barfoo:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_interested_in_self(self): @@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.type = "m.room.member" self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) - self.store.get_users_in_room.return_value = [ - "@alice:here", - "@irc_fo:here", # AS user - "@bob:here", - ] - self.store.get_aliases_for_room.return_value = [] + # Note that @irc_fo:here is the AS user. + self.store.get_users_in_room.return_value = defer.succeed( + ["@alice:here", "@irc_fo:here", "@bob:here"] + ) + self.store.get_aliases_for_room.return_value = defer.succeed([]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( - (yield self.service.is_interested(event=self.event, store=self.store)) + ( + yield defer.ensureDeferred( + self.service.is_interested(event=self.event, store=self.store) + ) + ) ) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 52f89d3f83..68a4caabbf 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -25,6 +25,7 @@ from synapse.appservice.scheduler import ( from synapse.logging.context import make_deferred_yieldable from tests import unittest +from tests.test_utils import make_awaitable from ..utils import MockClock @@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.get_appservice_state = Mock( return_value=defer.succeed(ApplicationServiceState.UP) ) - txn.send = Mock(return_value=defer.succeed(True)) + txn.send = Mock(return_value=make_awaitable(True)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): return_value=defer.succeed(ApplicationServiceState.UP) ) self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) - txn.send = Mock(return_value=defer.succeed(False)) # fails to send + txn.send = Mock(return_value=make_awaitable(False)) # fails to send self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events @@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=True) + txn.send = Mock(return_value=make_awaitable(True)) + txn.complete.return_value = make_awaitable(None) # wait for exp backoff self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) @@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=False) + txn.send = Mock(return_value=make_awaitable(False)) + txn.complete.return_value = make_awaitable(None) self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) @@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEquals(3, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, self.callback.call_count) - txn.send = Mock(return_value=True) # successfully send the txn + txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEquals(1, txn.send.call_count) # new mock reset call count diff --git a/tests/config/test_base.py b/tests/config/test_base.py new file mode 100644 index 0000000000..42ee5f56d9 --- /dev/null +++ b/tests/config/test_base.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path +import tempfile + +from synapse.config import ConfigError +from synapse.util.stringutils import random_string + +from tests import unittest + + +class BaseConfigTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.hs = hs + + def test_loading_missing_templates(self): + # Use a temporary directory that exists on the system, but that isn't likely to + # contain template files + with tempfile.TemporaryDirectory() as tmp_dir: + # Attempt to load an HTML template from our custom template directory + template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0] + + # If no errors, we should've gotten the default template instead + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"error_description": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_custom_templates(self): + # Use a temporary directory that exists on the system + with tempfile.TemporaryDirectory() as tmp_dir: + # Create a temporary bogus template file + with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template: + # Get temporary file's filename + template_filename = os.path.basename(tmp_template.name) + + # Write a custom HTML template + contents = b"{{ test_variable }}" + tmp_template.write(contents) + tmp_template.flush() + + # Attempt to load the template from our custom template directory + template = ( + self.hs.config.read_templates([template_filename], tmp_dir) + )[0] + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"test_variable": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_template_from_nonexistent_custom_directory(self): + with self.assertRaises(ConfigError): + self.hs.config.read_templates( + ["some_filename.html"], "a_nonexistent_directory" + ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 70c8e72303..2e6e7abf1f 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -40,9 +40,10 @@ from synapse.logging.context import ( from synapse.storage.keys import FetchKeyResult from tests import unittest +from tests.test_utils import make_awaitable -class MockPerspectiveServer(object): +class MockPerspectiveServer: def __init__(self): self.server_name = "mock_server" self.key = signedjson.key.generate_signing_key(0) @@ -102,11 +103,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): } persp_deferred = defer.Deferred() - @defer.inlineCallbacks - def get_perspectives(**kwargs): + async def get_perspectives(**kwargs): self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): - yield persp_deferred + await persp_deferred return persp_resp self.http_client.post_json.side_effect = get_perspectives @@ -190,9 +190,9 @@ class KeyringTestCase(unittest.HomeserverTestCase): # should fail immediately on an unsigned object d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") - self.failureResultOf(d, SynapseError) + self.get_failure(d, SynapseError) - # should suceed on a signed object + # should succeed on a signed object d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") # self.assertFalse(d.called) self.get_success(d) @@ -202,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): with a null `ts_valid_until_ms` """ mock_fetcher = keyring.KeyFetcher() - mock_fetcher.get_keys = Mock(return_value=defer.succeed({})) + mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) kr = keyring.Keyring( self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) @@ -221,7 +221,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # should fail immediately on an unsigned object d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") - self.failureResultOf(d, SynapseError) + self.get_failure(d, SynapseError) # should fail on a signed object with a non-zero minimum_valid_until_ms, # as it tries to refetch the keys and fails. @@ -245,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Two requests for the same key should be deduped.""" key1 = signedjson.key.generate_signing_key(1) - def get_keys(keys_to_fetch): + async def get_keys(keys_to_fetch): # there should only be one request object (with the max validity) self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } + return { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) } - ) + } mock_fetcher = keyring.KeyFetcher() mock_fetcher.get_keys = Mock(side_effect=get_keys) @@ -282,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase): """If the first fetcher cannot provide a recent enough key, we fall back""" key1 = signedjson.key.generate_signing_key(1) - def get_keys1(keys_to_fetch): + async def get_keys1(keys_to_fetch): self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800) - } - } - ) + return { + "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)} + } - def get_keys2(keys_to_fetch): + async def get_keys2(keys_to_fetch): self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } + return { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) } - ) + } mock_fetcher1 = keyring.KeyFetcher() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) @@ -355,7 +347,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): } signedjson.sign.sign_json(response, SERVER_NAME, testkey) - def get_json(destination, path, **kwargs): + async def get_json(destination, path, **kwargs): self.assertEqual(destination, SERVER_NAME) self.assertEqual(path, "/_matrix/key/v2/server/key1") return response @@ -444,7 +436,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect a perspectives-server key query """ - def post_json(destination, path, data, **kwargs): + async def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) self.assertEqual(path, "/_matrix/key/v2/query") @@ -580,14 +572,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # remove the perspectives server's signature response = build_response() del response["signatures"][self.mock_perspective_server.server_name] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig") # remove the origin server's signature response = build_response() del response["signatures"][SERVER_NAME] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 640f5f3bce..3a80626224 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase): serialize/deserialize. """ - event, context = create_event( - self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + event, context = self.get_success( + create_event( + self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + ) ) self._check_serialize_deserialize(event, context) @@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (with not previous entry) is the same after serialize/deserialize. """ - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.test", - sender=self.user_id, - state_key="", + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + state_key="", + ) ) self._check_serialize_deserialize(event, context) @@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (which replaces a previous entry) is the same after serialize/deserialize. """ - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.room.member", - sender=self.user_id, - state_key=self.user_id, - content={"membership": "leave"}, + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.user_id, + state_key=self.user_id, + content={"membership": "leave"}, + ) ) self._check_serialize_deserialize(event, context) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 0c9987be54..3d880c499d 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -15,14 +15,13 @@ from mock import Mock -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): @@ -59,7 +58,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Artificially raise the complexity store = self.hs.get_datastore() - store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23) + store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value request, channel = self.make_request( @@ -78,9 +77,44 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_join_too_large_admin(self): + # Check whether an admin can join if option "admins_can_join" is undefined, + # this option defaults to false, so the join should fail. + + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) + handler.federation_handler.do_invite_join = Mock( + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) d = handler._remote_join( @@ -116,13 +150,15 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable(None) + ) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) ) # Artificially raise the complexity - self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( + self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable( 600 ) @@ -141,3 +177,85 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): f = self.get_failure(d, SynapseError) self.assertEqual(f.value.code, 400) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + +class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): + # Test the behavior of joining rooms which exceed the complexity if option + # limit_remote_rooms.admins_can_join is True. + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["limit_remote_rooms"] = { + "enabled": True, + "complexity": 0.05, + "admins_can_join": True, + } + return config + + def test_join_too_large_no_admin(self): + # A user which is not an admin should not be able to join a remote room + # which is too complex. + + u1 = self.register_user("u1", "pass") + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) + handler.federation_handler.do_invite_join = Mock( + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_join_too_large_admin(self): + # An admin should be able to join rooms where a complexity check fails. + + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock( + side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) + ) + handler.federation_handler.do_invite_join = Mock( + side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request success since the user is an admin + self.get_success(d) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index ff12539041..5f512ff8bf 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -21,35 +21,39 @@ from signedjson.types import BaseKey, SigningKey from twisted.internet import defer +from synapse.api.constants import RoomEncryptionAlgorithms from synapse.rest import admin from synapse.rest.client.v1 import login from synapse.types import JsonDict, ReadReceipt +from tests.test_utils import make_awaitable from tests.unittest import HomeserverTestCase, override_config class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): + mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) + # Ensure a new Awaitable is created for each call. + mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( + ["test", "host2"] + ) return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), + state_handler=mock_state_handler, federation_transport_client=Mock(spec=["send_transaction"]), ) @override_config({"send_federation": True}) def test_send_receipts(self): - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -80,19 +84,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -124,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): receipt = ReadReceipt( "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() mock_send_transaction.assert_not_called() @@ -163,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -173,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - # stub out get_current_hosts_in_room - mock_state_handler = hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - # stub out get_users_who_share_room_with_user so that it claims that # `@user2:host2` is in the room def get_users_who_share_room_with_user(user_id): @@ -536,7 +532,10 @@ def build_device_dict(user_id: str, device_id: str, sk: SigningKey): return { "user_id": user_id, "device_id": device_id, - "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], "keys": { "curve25519:" + device_id: "curve25519+key", key_id(sk): encode_pubkey(sk), diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 296dc887be..da933ecd75 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -15,6 +15,8 @@ # limitations under the License. import logging +from parameterized import parameterized + from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin @@ -23,6 +25,37 @@ from synapse.rest.client.v1 import login, room from tests import unittest +class FederationServerTests(unittest.FederatingHomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)]) + def test_bad_request(self, query_content): + """ + Querying with bad data returns a reasonable error code. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.inject_room_member(room_1, "@user:other.example.com", "join") + + "/get_missing_events/(?P<room_id>[^/]*)/?" + + request, channel = self.make_request( + "POST", + "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), + query_content, + ) + self.render(request) + self.assertEquals(400, channel.code, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") + + class ServerACLsTestCase(unittest.TestCase): def test_blacklisted_server(self): e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index 27d83bb7d9..72e22d655f 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -26,7 +26,7 @@ from tests.unittest import override_config class RoomDirectoryFederationTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - class Authenticator(object): + class Authenticator: def authenticate_request(self, request, content): return defer.succeed("otherserver.nottld") diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index ba7148ec01..2a0b7c1b56 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.handlers.appservice import ApplicationServicesHandler +from tests.test_utils import make_awaitable from tests.utils import MockClock from .. import unittest @@ -32,10 +33,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api = Mock() self.mock_scheduler = Mock() hs = Mock() - hs.get_datastore = Mock(return_value=self.mock_store) - self.mock_store.get_received_ts.return_value = 0 - hs.get_application_service_api = Mock(return_value=self.mock_as_api) - hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) + hs.get_datastore.return_value = self.mock_store + self.mock_store.get_received_ts.return_value = defer.succeed(0) + self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None) + hs.get_application_service_api.return_value = self.mock_as_api + hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() self.handler = ApplicationServicesHandler(hs) @@ -48,18 +50,18 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice(is_interested=False), ] - self.mock_store.get_app_services = Mock(return_value=services) - self.mock_store.get_user_by_id = Mock(return_value=[]) + self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_store.get_app_services.return_value = services + self.mock_store.get_user_by_id.return_value = defer.succeed([]) event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) self.mock_store.get_new_events_for_appservice.side_effect = [ - (0, [event]), - (0, []), + defer.succeed((0, [event])), + defer.succeed((0, [])), ] - self.mock_as_api.push = Mock() - yield self.handler.notify_interested_services(0) + yield defer.ensureDeferred(self.handler.notify_interested_services(0)) self.mock_scheduler.submit_event_for_as.assert_called_once_with( interested_service, event ) @@ -68,36 +70,34 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] - services[0].is_interested_in_user = Mock(return_value=True) - self.mock_store.get_app_services = Mock(return_value=services) - self.mock_store.get_user_by_id = Mock(return_value=None) + services[0].is_interested_in_user.return_value = True + self.mock_store.get_app_services.return_value = services + self.mock_store.get_user_by_id.return_value = defer.succeed(None) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.push = Mock() - self.mock_as_api.query_user = Mock() + self.mock_as_api.query_user.return_value = defer.succeed(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - (0, [event]), - (0, []), + defer.succeed((0, [event])), + defer.succeed((0, [])), ] - yield self.handler.notify_interested_services(0) + yield defer.ensureDeferred(self.handler.notify_interested_services(0)) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @defer.inlineCallbacks def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] - services[0].is_interested_in_user = Mock(return_value=True) - self.mock_store.get_app_services = Mock(return_value=services) - self.mock_store.get_user_by_id = Mock(return_value={"name": user_id}) + services[0].is_interested_in_user.return_value = True + self.mock_store.get_app_services.return_value = services + self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id}) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.push = Mock() - self.mock_as_api.query_user = Mock() + self.mock_as_api.query_user.return_value = defer.succeed(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - (0, [event]), - (0, []), + defer.succeed((0, [event])), + defer.succeed((0, [])), ] - yield self.handler.notify_interested_services(0) + yield defer.ensureDeferred(self.handler.notify_interested_services(0)) self.assertFalse( self.mock_as_api.query_user.called, "query_user called when it shouldn't have been.", @@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_query_room_alias_exists(self): room_alias_str = "#foo:bar" room_alias = Mock() - room_alias.to_string = Mock(return_value=room_alias_str) + room_alias.to_string.return_value = room_alias_str room_id = "!alpha:bet" servers = ["aperture"] @@ -118,12 +118,15 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice_alias(is_interested_in_alias=False), ] - self.mock_store.get_app_services = Mock(return_value=services) - self.mock_store.get_association_from_room_alias = Mock( - return_value=Mock(room_id=room_id, servers=servers) + self.mock_as_api.query_alias.return_value = make_awaitable(True) + self.mock_store.get_app_services.return_value = services + self.mock_store.get_association_from_room_alias.return_value = make_awaitable( + Mock(room_id=room_id, servers=servers) ) - result = yield self.handler.query_room_alias_exists(room_alias) + result = yield defer.ensureDeferred( + self.handler.query_room_alias_exists(room_alias) + ) self.mock_as_api.query_alias.assert_called_once_with( interested_service, room_alias_str @@ -133,14 +136,14 @@ class AppServiceHandlerTestCase(unittest.TestCase): def _mkservice(self, is_interested): service = Mock() - service.is_interested = Mock(return_value=is_interested) + service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" return service def _mkservice_alias(self, is_interested_in_alias): service = Mock() - service.is_interested_in_alias = Mock(return_value=is_interested_in_alias) + service.is_interested_in_alias.return_value = is_interested_in_alias service.token = "mock_service_token" service.url = "mock_service_url" return service diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index c01b04e1dc..c7efd3822d 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -24,10 +24,11 @@ from synapse.api.errors import ResourceLimitError from synapse.handlers.auth import AuthHandler from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver -class AuthHandlers(object): +class AuthHandlers: def __init__(self, hs): self.auth_handler = AuthHandler(hs) @@ -142,7 +143,7 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.large_number_of_users) + side_effect=lambda: make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): @@ -153,7 +154,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.large_number_of_users) + side_effect=lambda: make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -168,7 +169,7 @@ class AuthTestCase(unittest.TestCase): # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -178,7 +179,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -188,10 +189,10 @@ class AuthTestCase(unittest.TestCase): ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -199,10 +200,10 @@ class AuthTestCase(unittest.TestCase): ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( @@ -215,7 +216,7 @@ class AuthTestCase(unittest.TestCase): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.small_number_of_users) + side_effect=lambda: make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception yield defer.ensureDeferred( @@ -225,7 +226,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.small_number_of_users) + side_effect=lambda: make_awaitable(self.small_number_of_users) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 62b47f6574..6aa322bf3a 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - res = self.handler.get_device(user1, "abc") - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError ) # we'd like to check the access token was invalidated, but that's a @@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): def test_update_unknown_device(self): update = {"display_name": "new_display"} - res = self.handler.update_device("user_id", "unknown_device_id", update) - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.update_device("user_id", "unknown_device_id", update), + synapse.api.errors.NotFoundError, ) def _record_users(self): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 00bb776271..bc0c5aefdc 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -16,8 +16,6 @@ from mock import Mock -from twisted.internet import defer - import synapse import synapse.api.errors from synapse.api.constants import EventTypes @@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest +from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): @@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index e1e144b2e7..210ddcbb88 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,17 +14,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import mock -import signedjson.key as key -import signedjson.sign as sign +from signedjson import key as key, sign as sign from twisted.internet import defer import synapse.handlers.e2e_keys import synapse.storage from synapse.api import errors +from synapse.api.constants import RoomEncryptionAlgorithms from tests import unittest, utils @@ -47,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): """If the user has no devices, we expect an empty list. """ local_user = "@boris:" + self.hs.hostname - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -61,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) @@ -85,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + ) ) self.fail("No error when changing string key") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + ) ) self.fail("No error when replacing dict key with string") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": {"alg1:k1": {"key": "key"}}}, + ) ) self.fail("No error when replacing string key with dict") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, - device_id, - { - "one_time_keys": { - "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} - } - }, + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + { + "one_time_keys": { + "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} + } + }, + ) ) self.fail("No error when replacing dict key") except errors.SynapseError: @@ -134,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id = "xyz" keys = {"alg1:k1": "key1"} - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) - res2 = yield self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + res2 = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) ) self.assertEqual( res2, @@ -164,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) keys2 = { "master_key": { @@ -176,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys2) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys2) + ) - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) @@ -216,13 +241,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", ) - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) # upload two device keys, which will be signed later by the self-signing key device_key_1 = { "user_id": local_user, "device_id": "abc", - "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], "keys": { "ed25519:abc": "base64+ed25519+key", "curve25519:abc": "base64+curve25519+key", @@ -232,7 +262,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_2 = { "user_id": local_user, "device_id": "def", - "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], "keys": { "ed25519:def": "base64+ed25519+key", "curve25519:def": "base64+curve25519+key", @@ -240,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "signatures": {local_user: {"ed25519:def": "base64+signature"}}, } - yield self.handler.upload_keys_for_user( - local_user, "abc", {"device_keys": device_key_1} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) ) - yield self.handler.upload_keys_for_user( - local_user, "def", {"device_keys": device_key_2} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) ) # sign the first device key and upload it del device_key_1["signatures"] sign.sign_json(device_key_1, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) ) # sign the second device key and upload both device keys. The server @@ -259,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # signature for it del device_key_2["signatures"] sign.sign_json(device_key_2, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) ) device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] @@ -287,20 +328,26 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) res = None try: - yield self.hs.get_device_handler().check_device_registered( - user_id=local_user, - device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", - initial_device_display_name="new display name", + yield defer.ensureDeferred( + self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ) ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -315,7 +362,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key = { "user_id": local_user, "device_id": device_id, - "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey}, "signatures": {local_user: {"ed25519:xyz": "something"}}, } @@ -323,8 +373,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) - yield self.handler.upload_keys_for_user( - local_user, device_id, {"device_keys": device_key} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) ) # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 @@ -364,7 +416,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_signing_key": usersigning_key, "self_signing_key": selfsigning_key, } - yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + ) # set up another user with a master key. This user will be signed by # the first user @@ -376,76 +430,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "usage": ["master"], "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield self.handler.upload_signing_keys_for_user( - other_user, {"master_key": other_master_key} + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) ) # test various signature failures (see below) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: { - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - device_id: { - "user_id": local_user, - "device_id": device_id, - "algorithms": [ - "m.olm.curve25519-aes-sha2", - "m.megolm.v1.aes-sha2", - ], - "keys": { - "curve25519:xyz": "curve25519+key", - # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA - "ed25519:xyz": device_pubkey, + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey, + }, + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - }, - # fails because device is unknown - # should fail with NOT_FOUND - "unknown": { - "user_id": local_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + "signatures": { + local_user: {"ed25519:" + device_pubkey: "something"} + }, }, }, - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - master_pubkey: { - "user_id": local_user, - "usage": ["master"], - "keys": {"ed25519:" + master_pubkey: master_pubkey}, - "signatures": { - local_user: {"ed25519:" + device_pubkey: "something"} + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, - }, - }, - other_user: { - # fails because the device is not the user's master-signing key - # should fail with NOT_FOUND - "unknown": { - "user_id": other_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} - }, - }, - other_master_pubkey: { - # fails because the key doesn't match what the server has - # should fail with UNKNOWN - "user_id": other_user, - "usage": ["master"], - "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, - "something": "random", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} + other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + }, + "something": "random", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, }, }, - }, + ) ) user_failures = ret["failures"][local_user] @@ -470,19 +538,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase): sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: {device_id: device_key, master_pubkey: master_key}, - other_user: {other_master_pubkey: other_master_key}, - }, + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, + ) ) self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there - ret = yield self.handler.query_devices( - {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ret = yield defer.ensureDeferred( + self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ) ) self.assertEqual( diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 70f172eb02..3362050ce0 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user) + yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -87,15 +89,21 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) version_etag = res["etag"] + self.assertIsInstance(version_etag, str) del res["etag"] self.assertDictEqual( res, @@ -108,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as a specific version - res = yield self.handler.get_version_info(self.local_user, "1") + res = yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) self.assertEqual(res["etag"], version_etag) del res["etag"] self.assertDictEqual( @@ -122,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # upload a new one... - res = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(res, "2") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -148,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_version(self): """Check that we can update versions. """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version, - }, + res = yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -184,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.update_version( - self.local_user, - "1", - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -201,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_omitted_version(self): """Check that the update succeeds if the version is missing from the body """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) ) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, @@ -233,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_bad_version(self): """Check that we get a 400 if the version in the body doesn't match """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -260,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.delete_version(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -271,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user) + yield defer.ensureDeferred(self.handler.delete_version(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -280,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can delete it - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1")) # check that it's gone res = None try: - yield self.handler.get_version_info(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -303,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_room_keys(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -312,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, @@ -330,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.upload_room_keys( - self.local_user, "no_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "no_version", room_keys) ) except errors.SynapseError as e: res = e.code @@ -342,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.upload_room_keys( - self.local_user, "bogus_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys( + self.local_user, "bogus_version", room_keys + ) ) except errors.SynapseError as e: res = e.code @@ -361,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - version = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(version, "2") res = None try: - yield self.handler.upload_room_keys(self.local_user, "1", room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "1", room_keys) + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 403) @@ -387,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given room - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, room_keys) @@ -414,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) # get the etag to compare to future versions - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) backup_etag = res["etag"] self.assertEqual(res["count"], 1) @@ -433,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # test that increasing the message_index doesn't replace the existing session new_room_key["first_message_index"] = 2 new_room_key["session_data"] = "new" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should NOT be equal now, since the key changed - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertNotEqual(res["etag"], backup_etag) backup_etag = res["etag"] @@ -463,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # with a lower forwarding count new_room_key["forwarded_count"] = 2 new_room_key["session_data"] = "other" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # TODO: check edge cases as well as the common variations here @@ -480,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") # check for bulk-delete - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys(self.local_user, version) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys(self.local_user, version) + ) + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1bb25ab684..89ec5fcb31 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -75,7 +75,17 @@ COMMON_CONFIG = { COOKIE_NAME = b"oidc_session" COOKIE_PATH = "/_synapse/oidc" -MockedMappingProvider = Mock(OidcMappingProvider) + +class TestMappingProvider(OidcMappingProvider): + @staticmethod + def parse_config(config): + return + + def get_remote_user_id(self, userinfo): + return userinfo["sub"] + + async def map_user_attributes(self, userinfo, token): + return {"localpart": userinfo["username"], "display_name": None} def simple_async_mock(return_value=None, raises=None): @@ -123,7 +133,7 @@ class OidcHandlerTestCase(HomeserverTestCase): oidc_config["issuer"] = ISSUER oidc_config["scopes"] = SCOPES oidc_config["user_mapping_provider"] = { - "module": __name__ + ".MockedMappingProvider" + "module": __name__ + ".TestMappingProvider", } config["oidc_config"] = oidc_config @@ -374,12 +384,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) code = "code" state = "state" nonce = "nonce" client_redirect_url = "http://client/redirect" + user_agent = "Browser" + ip_address = "10.0.0.1" session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, @@ -392,6 +406,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] + request.getClientIP.return_value = ip_address + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( @@ -399,7 +417,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_not_called() self.handler._render_error.assert_not_called() @@ -431,7 +451,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._render_error.assert_not_called() @@ -568,3 +590,30 @@ class OidcHandlerTestCase(HomeserverTestCase): with self.assertRaises(OidcError) as exc: yield defer.ensureDeferred(self.handler._exchange_code(code)) self.assertEqual(exc.exception.error, "some_error") + + def test_map_userinfo_to_user(self): + """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" + userinfo = { + "sub": "test_user", + "username": "test_user", + } + # The token doesn't matter with the default user mapping provider. + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + # Some providers return an integer ID. + userinfo = { + "sub": 1234, + "username": "test_user_2", + } + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user_2:test") diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 05ea40a7de..306dcfe944 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -19,6 +19,7 @@ from mock import Mock, call from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState +from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events.builder import EventBuilder from synapse.handlers.presence import ( @@ -32,7 +33,6 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest.client.v1 import room -from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from tests import unittest diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 29dd7d9c6e..8e95e53d9e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -24,10 +24,11 @@ from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver -class ProfileHandlers(object): +class ProfileHandlers: def __init__(self, hs): self.profile_handler = MasterProfileHandler(hs) @@ -63,16 +64,20 @@ class ProfileTestCase(unittest.TestCase): self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield self.store.create_profile(self.frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart)) self.handler = hs.get_profile_handler() self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) - displayname = yield self.handler.get_displayname(self.frank) + displayname = yield defer.ensureDeferred( + self.handler.get_displayname(self.frank) + ) self.assertEquals("Frank", displayname) @@ -101,7 +106,12 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) @defer.inlineCallbacks @@ -109,10 +119,17 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_displayname = False # Setting displayname for the first time is allowed - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) # Setting displayname a second time is forbidden @@ -136,11 +153,13 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_other_name(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) - displayname = yield self.handler.get_displayname(self.alice) + displayname = yield defer.ensureDeferred( + self.handler.get_displayname(self.alice) + ) self.assertEquals(displayname, "Alice") self.mock_federation.make_query.assert_called_with( @@ -152,22 +171,27 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_profile("caroline") - yield self.store.set_profile_displayname("caroline", "Caroline") + yield defer.ensureDeferred(self.store.create_profile("caroline")) + yield defer.ensureDeferred( + self.store.set_profile_displayname("caroline", "Caroline") + ) - response = yield self.query_handlers["profile"]( - {"user_id": "@caroline:test", "field": "displayname"} + response = yield defer.ensureDeferred( + self.query_handlers["profile"]( + {"user_id": "@caroline:test", "field": "displayname"} + ) ) self.assertEquals({"displayname": "Caroline"}, response) @defer.inlineCallbacks def test_get_my_avatar(self): - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) - - avatar_url = yield self.handler.get_avatar_url(self.frank) + avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) self.assertEquals("http://my.server/me.png", avatar_url) @@ -182,7 +206,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/pic.gif", ) @@ -196,7 +224,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) @@ -205,12 +237,18 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_avatar_url = False # Setting displayname for the first time is allowed - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index ca32f993a3..eddf5e2498 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -15,17 +15,21 @@ from mock import Mock -from twisted.internet import defer - +from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler +from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester +from tests.test_utils import make_awaitable +from tests.unittest import override_config +from tests.utils import mock_getRawHeaders + from .. import unittest -class RegistrationHandlers(object): +class RegistrationHandlers: def __init__(self, hs): self.registration_handler = RegistrationHandler(hs) @@ -96,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value - 1) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) @@ -104,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.lots_of_users) + side_effect=lambda: make_awaitable(self.lots_of_users) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -112,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -122,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.lots_of_users) + side_effect=lambda: make_awaitable(self.lots_of_users) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError @@ -145,9 +149,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -185,7 +189,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=defer.succeed(False)) + self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -193,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias = RoomAlias.from_string(room_alias_str) self.get_failure(directory_handler.get_association(room_alias), SynapseError) + @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): room_alias_str = "#room:test" - self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=defer.succeed(1)) - self.store.is_real_user = Mock(return_value=defer.succeed(True)) + self.store.count_real_users = Mock(return_value=make_awaitable(1)) + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -212,12 +216,218 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=defer.succeed(2)) - self.store.is_real_user = Mock(return_value=defer.succeed(True)) + self.store.count_real_users = Mock(return_value=make_awaitable(2)) + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) + @override_config( + { + "auto_join_rooms": ["#room:test"], + "autocreate_auto_join_rooms_federated": False, + } + ) + def test_auto_create_auto_join_rooms_federated(self): + """ + Auto-created rooms that are private require an invite to go to the user + (instead of directly joining it). + """ + room_alias_str = "#room:test" + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + + # Ensure the room was created. + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room is properly not federated. + room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + self.assertFalse(room["federatable"]) + self.assertFalse(room["public"]) + self.assertEqual(room["join_rules"], "public") + self.assertIsNone(room["guest_access"]) + + # The user should be in the room. + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + @override_config( + {"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"} + ) + def test_auto_join_mxid_localpart(self): + """ + Ensure the user still needs up in the room created by a different user. + """ + # Ensure the support user exists. + inviter = "@support:test" + + room_alias_str = "#room:test" + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + + # Ensure the room was created. + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room is properly a public room. + room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + self.assertEqual(room["join_rules"], "public") + + # Both users should be in the room. + rooms = self.get_success(self.store.get_rooms_for_user(inviter)) + self.assertIn(room_id["room_id"], rooms) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + # Register a second user, which should also end up in the room. + user_id = self.get_success(self.handler.register_user(localpart="bob")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + @override_config( + { + "auto_join_rooms": ["#room:test"], + "autocreate_auto_join_room_preset": "private_chat", + "auto_join_mxid_localpart": "support", + } + ) + def test_auto_create_auto_join_room_preset(self): + """ + Auto-created rooms that are private require an invite to go to the user + (instead of directly joining it). + """ + # Ensure the support user exists. + inviter = "@support:test" + + room_alias_str = "#room:test" + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + + # Ensure the room was created. + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room is properly a private room. + room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + self.assertFalse(room["public"]) + self.assertEqual(room["join_rules"], "invite") + self.assertEqual(room["guest_access"], "can_join") + + # Both users should be in the room. + rooms = self.get_success(self.store.get_rooms_for_user(inviter)) + self.assertIn(room_id["room_id"], rooms) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + # Register a second user, which should also end up in the room. + user_id = self.get_success(self.handler.register_user(localpart="bob")) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + @override_config( + { + "auto_join_rooms": ["#room:test"], + "autocreate_auto_join_room_preset": "private_chat", + "auto_join_mxid_localpart": "support", + } + ) + def test_auto_create_auto_join_room_preset_guest(self): + """ + Auto-created rooms that are private require an invite to go to the user + (instead of directly joining it). + + This should also work for guests. + """ + inviter = "@support:test" + + room_alias_str = "#room:test" + user_id = self.get_success( + self.handler.register_user(localpart="jeff", make_guest=True) + ) + + # Ensure the room was created. + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room is properly a private room. + room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + self.assertFalse(room["public"]) + self.assertEqual(room["join_rules"], "invite") + self.assertEqual(room["guest_access"], "can_join") + + # Both users should be in the room. + rooms = self.get_success(self.store.get_rooms_for_user(inviter)) + self.assertIn(room_id["room_id"], rooms) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + @override_config( + { + "auto_join_rooms": ["#room:test"], + "autocreate_auto_join_room_preset": "private_chat", + "auto_join_mxid_localpart": "support", + } + ) + def test_auto_create_auto_join_room_preset_invalid_permissions(self): + """ + Auto-created rooms that are private require an invite, check that + registration doesn't completely break if the inviter doesn't have proper + permissions. + """ + inviter = "@support:test" + + # Register an initial user to create the room and such (essentially this + # is a subset of test_auto_create_auto_join_room_preset). + room_alias_str = "#room:test" + user_id = self.get_success(self.handler.register_user(localpart="jeff")) + + # Ensure the room was created. + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = self.get_success(directory_handler.get_association(room_alias)) + + # Ensure the room exists. + self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + + # Both users should be in the room. + rooms = self.get_success(self.store.get_rooms_for_user(inviter)) + self.assertIn(room_id["room_id"], rooms) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertIn(room_id["room_id"], rooms) + + # Lower the permissions of the inviter. + event_creation_handler = self.hs.get_event_creation_handler() + requester = create_requester(inviter) + event, context = self.get_success( + event_creation_handler.create_event( + requester, + { + "type": "m.room.power_levels", + "state_key": "", + "room_id": room_id["room_id"], + "content": {"invite": 100, "users": {inviter: 0}}, + "sender": inviter, + }, + ) + ) + self.get_success( + event_creation_handler.send_nonmember_event(requester, event, context) + ) + + # Register a second user, which won't be be in the room (or even have an invite) + # since the inviter no longer has the proper permissions. + user_id = self.get_success(self.handler.register_user(localpart="bob")) + + # This user should not be in any rooms. + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(user_id) + ) + self.assertEqual(rooms, set()) + self.assertEqual(invited_rooms, []) + def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. @@ -266,6 +476,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) + def test_spam_checker_deny(self): + """A spam checker can deny registration, which results in an error.""" + + class DenyAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.DENY + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [DenyAll()] + + self.get_failure(self.handler.register_user(localpart="user"), SynapseError) + + def test_spam_checker_shadow_ban(self): + """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" + + class BanAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.SHADOW_BAN + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanAll()] + + user_id = self.get_success(self.handler.register_user(localpart="user")) + + # Get an access token. + token = self.macaroon_generator.generate_access_token(user_id) + self.get_success( + self.store.add_access_token_to_user( + user_id=user_id, token=token, device_id=None, valid_until_ms=None + ) + ) + + # Ensure the user was marked as shadow-banned. + request = Mock(args={}) + request.args[b"access_token"] = [token.encode("ascii")] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + auth = Auth(self.hs) + requester = self.get_success(auth.get_user_by_req(request)) + + self.assertTrue(requester.shadow_banned) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index d9d312f0fb..a609f148c0 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -15,7 +15,7 @@ from synapse.rest import admin from synapse.rest.client.v1 import login, room -from synapse.storage.data_stores.main import stats +from synapse.storage.databases.main import stats from tests import unittest @@ -42,36 +42,36 @@ class StatsRoomTests(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { - "update_name": "populate_stats_process_rooms", + "update_name": "populate_stats_process_rooms_2", "progress_json": "{}", "depends_on": "populate_stats_prepare", }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms", + "depends_on": "populate_stats_process_rooms_2", }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - def get_all_room_state(self): - return self.store.db.simple_select_list( + async def get_all_room_state(self): + return await self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store.db.simple_select_one( + self.store.db_pool.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -109,10 +109,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def test_initial_room(self): @@ -146,10 +146,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) r = self.get_success(self.get_all_room_state()) @@ -186,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_update_one( + self.store.db_pool.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -196,17 +196,17 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Now, before the table is actually ingested, add some more events. @@ -217,28 +217,31 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", - {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, + { + "update_name": "populate_stats_process_rooms_2", + "progress_json": "{}", + }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms", + "depends_on": "populate_stats_process_rooms_2", }, ) ) - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) self.reactor.advance(86401) @@ -253,7 +256,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # self.handler.notify_new_event() # We need to let the delta processor advance… - self.pump(10 * 60) + self.reactor.advance(10 * 60) # Get the slices! There should be two -- day 1, and day 2. r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0)) @@ -346,6 +349,37 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + def test_updating_profile_information_does_not_increase_joined_members_count(self): + """ + Check that the joined_members count does not increase when a user changes their + profile information (which is done by sending another join membership event into + the room. + """ + self._perform_background_initial_update() + + # Create a user and room + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + # Get the current room stats + r1stats_ante = self._get_current_stats("room", r1) + + # Send a profile update into the room + new_profile = {"displayname": "bob"} + self.helper.change_membership( + r1, u1, u1, "join", extra_data=new_profile, tok=u1token + ) + + # Get the new room stats + r1stats_post = self._get_current_stats("room", r1) + + # Ensure that the user count did not changed + self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"]) + self.assertEqual( + r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"] + ) + def test_send_state_event_nonoverwriting(self): """ When we send a non-overwriting state event, it increments total_events AND current_state_events @@ -669,15 +703,15 @@ class StatsRoomTests(unittest.HomeserverTestCase): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_delete( + self.store.db_pool.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store.db.simple_delete( + self.store.db_pool.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -689,29 +723,29 @@ class StatsRoomTests(unittest.HomeserverTestCase): # now do the background updates - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { - "update_name": "populate_stats_process_rooms", + "update_name": "populate_stats_process_rooms_2", "progress_json": "{}", "depends_on": "populate_stats_prepare", }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms", + "depends_on": "populate_stats_process_rooms_2", }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -722,10 +756,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) r1stats_complete = self._get_current_stats("room", r1) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 2fa8d4739b..7bf15c4ba9 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -21,9 +21,10 @@ from mock import ANY, Mock, call from twisted.internet import defer from synapse.api.errors import AuthError -from synapse.types import UserID +from synapse.types import UserID, create_requester from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import register_federation_servlets @@ -115,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable( (0, []) ) @@ -126,9 +127,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - def check_user_in_room(room_id, user_id): + async def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") + return None hs.get_auth().check_user_in_room = check_user_in_room @@ -137,24 +139,26 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - def get_current_users_in_room(room_id): - return {str(u) for u in self.room_members} + def get_users_in_room(room_id): + return defer.succeed({str(u) for u in self.room_members}) - hs.get_state_handler().get_current_users_in_room = get_current_users_in_room + self.datastore.get_users_in_room = get_users_in_room - self.datastore.get_user_directory_stream_pos.return_value = ( + self.datastore.get_user_directory_stream_pos.side_effect = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam - defer.succeed(1) + lambda: make_awaitable(1) ) self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( ([], 0) ) - self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None - self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( + None + ) + self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( None ) @@ -163,9 +167,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf( + self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=20000, ) ) @@ -190,9 +197,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] - self.successResultOf( + self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=20000, ) ) @@ -265,9 +275,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf( + self.get_success( self.handler.stopped_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, ) ) @@ -305,9 +317,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf( + self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=10000, ) ) @@ -344,9 +359,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): # SYN-230 - see if we can still set after timeout - self.successResultOf( + self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=10000, ) ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c15bce5bef..87be94111f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -17,12 +17,13 @@ from mock import Mock from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import UserTypes +from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest +from tests.unittest import override_config class UserDirectoryTestCase(unittest.HomeserverTestCase): @@ -147,9 +148,97 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) + @override_config({"encryption_enabled_by_default_for_room_type": "all"}) + def test_encrypted_by_default_config_option_all(self): + """Tests that invite-only and non-invite-only rooms have encryption enabled by + default when the config option encryption_enabled_by_default_for_room_type is "all". + """ + # Create a user + user = self.register_user("user", "pass") + user_token = self.login(user, "pass") + + # Create an invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) + + # Check that the room has an encryption state event + event_content = self.helper.get_state( + room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token, + ) + self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) + + # Create a non invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check that the room has an encryption state event + event_content = self.helper.get_state( + room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token, + ) + self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) + + @override_config({"encryption_enabled_by_default_for_room_type": "invite"}) + def test_encrypted_by_default_config_option_invite(self): + """Tests that only new, invite-only rooms have encryption enabled by default when + the config option encryption_enabled_by_default_for_room_type is "invite". + """ + # Create a user + user = self.register_user("user", "pass") + user_token = self.login(user, "pass") + + # Create an invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) + + # Check that the room has an encryption state event + event_content = self.helper.get_state( + room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token, + ) + self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) + + # Create a non invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check that the room does not have an encryption state event + self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + expect_code=404, + ) + + @override_config({"encryption_enabled_by_default_for_room_type": "off"}) + def test_encrypted_by_default_config_option_off(self): + """Tests that neither new invite-only nor non-invite-only rooms have encryption + enabled by default when the config option + encryption_enabled_by_default_for_room_type is "off". + """ + # Create a user + user = self.register_user("user", "pass") + user_token = self.login(user, "pass") + + # Create an invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) + + # Check that the room does not have an encryption state event + self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + expect_code=404, + ) + + # Create a non invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check that the room does not have an encryption state event + self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + expect_code=404, + ) + def test_spam_checker(self): """ - A user which fails to the spam checks will not appear in search results. + A user which fails the spam checks will not appear in search results. """ u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") @@ -180,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that does not filter any users. spam_checker = self.hs.get_spam_checker() - class AllowAll(object): + class AllowAll: def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -193,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - class BlockAll(object): + class BlockAll: def check_username_for_spam(self, user_profile): # All users are spammy. return True @@ -250,7 +339,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -261,7 +350,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -273,10 +362,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -285,7 +374,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -295,7 +384,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -305,7 +394,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", @@ -348,10 +437,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) shares_private = self.get_users_who_share_private_rooms() @@ -387,10 +476,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) shares_private = self.get_users_who_share_private_rooms() diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 2096ba3c91..5d41443293 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -133,7 +133,7 @@ def create_test_cert_file(sanlist): @implementer(IOpenSSLServerConnectionCreator) -class TestServerTLSConnectionFactory(object): +class TestServerTLSConnectionFactory: """An SSL connection creator which returns connections which present a certificate signed by our test CA.""" diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 562397cdda..8b5ad4574f 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -67,6 +67,14 @@ def get_connection_factory(): return test_server_connection_factory +# Once Async Mocks or lambdas are supported this can go away. +def generate_resolve_service(result): + async def resolve_service(_): + return result + + return resolve_service + + class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -86,6 +94,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.well_known_resolver = WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=self.tls_factory), + b"test-agent", well_known_cache=self.well_known_cache, had_well_known_cache=self.had_well_known_cache, ) @@ -93,6 +102,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=self.tls_factory, + user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) @@ -186,6 +196,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # check the .well-known request and send a response self.assertEqual(len(well_known_server.requests), 1) request = well_known_server.requests[0] + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) self._send_well_known_response(request, content, headers=response_headers) return well_known_server @@ -231,6 +244,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) content = request.content.read() self.assertEqual(content, b"") @@ -365,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the certificate on the server doesn't match the hostname """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv1/foo/bar") @@ -448,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has no port, no SRV, and no well-known """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -502,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the .well-known delegates elsewhere """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -564,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -653,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -709,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # the config left to the default, which will not trust it (since the # presented cert is signed by a test CA) - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" config = default_config("test", parse=True) @@ -719,10 +735,12 @@ class MatrixFederationAgentTests(unittest.TestCase): agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=tls_factory, + user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=tls_factory), + b"test-agent", well_known_cache=self.well_known_cache, had_well_known_cache=self.had_well_known_cache, ), @@ -754,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when there is a single SRV record """ - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self.reactor.lookups["srvtarget"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -809,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self._handle_well_known_connection( client_factory, @@ -851,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) # the resolver is always called with the IDNA hostname as a native string. self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" @@ -912,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"xn--trget-3qa.com", port=8443) # târget.com - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com + ) self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") @@ -954,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -977,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase): well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") @@ -985,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((1000.0,)) # now it should connect again - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -1008,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -1034,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # another lookup. self.reactor.pump((900.0,)) - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # The resolver may retry a few times, so fonx all requests that come along attempts = 0 @@ -1064,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((10000.0,)) # Repated the request, this time it should fail if the lookup fails. - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) clients = self.reactor.tcpClients (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -1077,11 +1107,12 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails. """ - - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"target.com", port=8443), - Server(host=b"target.com", port=8444), - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + ) self.reactor.lookups["target.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -1233,7 +1264,7 @@ def _log_request(request): @implementer(IPolicyForHTTPS) -class TrustingTLSPolicyForHTTPS(object): +class TrustingTLSPolicyForHTTPS: """An IPolicyForHTTPS which checks that the certificate belongs to the right server, but doesn't check the certificate chain.""" diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index babc201643..fee2985d35 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase): with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) - - self.assertNoResult(resolve_d) - - # should have reset to the sentinel context - self.assertIs(current_context(), SENTINEL_CONTEXT) - - result = yield resolve_d + result = yield defer.ensureDeferred(resolve_d) # should have restored our context self.assertIs(current_context(), ctx) @@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client=dns_client_mock, cache=cache, get_time=clock.time ) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertFalse(dns_client_mock.lookupService.called) @@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase): resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) with self.assertRaises(error.DNSServerError): - yield resolver.resolve_service(service_name) + yield defer.ensureDeferred(resolver.resolve_service(service_name)) @defer.inlineCallbacks def test_name_error(self): @@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertEquals(len(servers), 0) self.assertEquals(len(cache), 0) @@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in failureResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) # returning a single "." should make the lookup fail with a ConenctError lookup_deferred.callback( @@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in successResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) lookup_deferred.callback( ( diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py new file mode 100644 index 0000000000..62d36c2906 --- /dev/null +++ b/tests/http/test_additional_resource.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.http.additional_resource import AdditionalResource +from synapse.http.server import respond_with_json + +from tests.unittest import HomeserverTestCase + + +class _AsyncTestCustomEndpoint: + def __init__(self, config, module_api): + pass + + async def handle_request(self, request): + respond_with_json(request, 200, {"some_key": "some_value_async"}) + + +class _SyncTestCustomEndpoint: + def __init__(self, config, module_api): + pass + + async def handle_request(self, request): + respond_with_json(request, 200, {"some_key": "some_value_sync"}) + + +class AdditionalResourceTests(HomeserverTestCase): + """Very basic tests that `AdditionalResource` works correctly with sync + and async handlers. + """ + + def test_async(self): + handler = _AsyncTestCustomEndpoint({}, None).handle_request + self.resource = AdditionalResource(self.hs, handler) + + request, channel = self.make_request("GET", "/") + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) + + def test_sync(self): + handler = _SyncTestCustomEndpoint({}, None).handle_request + self.resource = AdditionalResource(self.hs, handler) + + request, channel = self.make_request("GET", "/") + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index fff4f0cbf4..5604af3795 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -16,6 +16,7 @@ from mock import Mock from netaddr import IPSet +from parameterized import parameterized from twisted.internet import defer from twisted.internet.defer import TimeoutError @@ -58,7 +59,9 @@ class FederationClientTests(HomeserverTestCase): @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: - fetch_d = self.cl.get_json("testserv:8008", "foo/bar") + fetch_d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar") + ) # Nothing happened yet self.assertNoResult(fetch_d) @@ -120,7 +123,9 @@ class FederationClientTests(HomeserverTestCase): """ If the DNS lookup returns an error, it will bubble up. """ - d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + ) self.pump() f = self.failureResultOf(d) @@ -128,7 +133,9 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value.inner_exception, DNSLookupError) def test_client_connection_refused(self): - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -154,7 +161,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -184,7 +193,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -226,7 +237,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a blacklisted IPv4 address # ------------------------------------------------------ # Make the request - d = cl.get_json("internal:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000)) # Nothing happened yet self.assertNoResult(d) @@ -244,7 +255,9 @@ class FederationClientTests(HomeserverTestCase): # Try making a POST request to a blacklisted IPv6 address # ------------------------------------------------------- # Make the request - d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + ) # Nothing has happened yet self.assertNoResult(d) @@ -263,7 +276,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a non-blacklisted IPv4 address # ---------------------------------------------------------- # Make the request - d = cl.post_json("fine:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000)) # Nothing has happened yet self.assertNoResult(d) @@ -286,7 +299,7 @@ class FederationClientTests(HomeserverTestCase): request = MatrixFederationRequest( method="GET", destination="testserv:8008", path="foo/bar" ) - d = self.cl._send_request(request, timeout=10000) + d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000)) self.pump() @@ -310,7 +323,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -342,7 +357,9 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -395,7 +412,9 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -432,7 +451,11 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) + defer.ensureDeferred( + self.cl.post_json( + "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} + ) + ) self.pump() @@ -453,7 +476,7 @@ class FederationClientTests(HomeserverTestCase): def test_closes_connection(self): """Check that the client closes unused HTTP connections""" - d = self.cl.get_json("testserv:8008", "foo/bar") + d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) self.pump() @@ -486,6 +509,53 @@ class FederationClientTests(HomeserverTestCase): self.assertFalse(conn.disconnecting) # wait for a while - self.pump(120) + self.reactor.advance(120) self.assertTrue(conn.disconnecting) + + @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)]) + def test_json_error(self, return_value): + """ + Test what happens if invalid JSON is returned from the remote endpoint. + """ + + test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) + + self.pump() + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8008) + + # complete the connection and wire it up to a fake transport + protocol = factory.buildProtocol(None) + transport = StringTransport() + protocol.makeConnection(transport) + + # that should have made it send the request to the transport + self.assertRegex(transport.value(), b"^GET /foo/bar") + self.assertRegex(transport.value(), b"Host: testserv:8008") + + # Deferred is still without a result + self.assertNoResult(test_d) + + # Send it the HTTP response + protocol.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Server: Fake\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: %i\r\n" + b"\r\n" + b"%s" % (len(return_value), return_value) + ) + + self.pump() + + f = self.failureResultOf(test_d) + self.assertIsInstance(f.value, ValueError) diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py new file mode 100644 index 0000000000..45089158ce --- /dev/null +++ b/tests/http/test_servlet.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from io import BytesIO + +from mock import Mock + +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + parse_json_object_from_request, + parse_json_value_from_request, +) + +from tests import unittest + + +def make_request(content): + """Make an object that acts enough like a request.""" + request = Mock(spec=["content"]) + + if isinstance(content, dict): + content = json.dumps(content).encode("utf8") + + request.content = BytesIO(content) + return request + + +class TestServletUtils(unittest.TestCase): + def test_parse_json_value(self): + """Basic tests for parse_json_value_from_request.""" + # Test round-tripping. + obj = {"foo": 1} + result = parse_json_value_from_request(make_request(obj)) + self.assertEqual(result, obj) + + # Results don't have to be objects. + result = parse_json_value_from_request(make_request(b'["foo"]')) + self.assertEqual(result, ["foo"]) + + # Test empty. + with self.assertRaises(SynapseError): + parse_json_value_from_request(make_request(b"")) + + result = parse_json_value_from_request(make_request(b""), allow_empty_body=True) + self.assertIsNone(result) + + # Invalid UTF-8. + with self.assertRaises(SynapseError): + parse_json_value_from_request(make_request(b"\xFF\x00")) + + # Invalid JSON. + with self.assertRaises(SynapseError): + parse_json_value_from_request(make_request(b"foo")) + + with self.assertRaises(SynapseError): + parse_json_value_from_request(make_request(b'{"foo": Infinity}')) + + def test_parse_json_object(self): + """Basic tests for parse_json_object_from_request.""" + # Test empty. + result = parse_json_object_from_request( + make_request(b""), allow_empty_body=True + ) + self.assertEqual(result, {}) + + # Test not an object + with self.assertRaises(SynapseError): + parse_json_object_from_request(make_request(b'["foo"]')) diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py index 451d05c0f0..d36f5f426c 100644 --- a/tests/logging/test_structured.py +++ b/tests/logging/test_structured.py @@ -29,12 +29,12 @@ from synapse.logging.context import LoggingContext from tests.unittest import DEBUG, HomeserverTestCase -class FakeBeginner(object): +class FakeBeginner: def beginLoggingTo(self, observers, **kwargs): self.observers = observers -class StructuredLoggingTestBase(object): +class StructuredLoggingTestBase: """ Test base that registers a cleanup handler to reset the stdlib log handler to 'unset'. diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 807cd65dd6..04de0b9dbe 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that the new user exists with all provided attributes self.assertEqual(user_id, "@bob:test") self.assertTrue(access_token) - self.assertTrue(self.store.get_user_by_id(user_id)) + self.assertTrue(self.get_success(self.store.get_user_by_id(user_id))) # Check that the email was assigned emails = self.get_success(self.store.user_get_threepids(user_id)) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 83032cc9ea..3224568640 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase @attr.s -class _User(object): +class _User: "Helper wrapper for user ID and access token" id = attr.ib() token = attr.ib() @@ -170,7 +170,7 @@ class EmailPusherTests(HomeserverTestCase): last_stream_ordering = pushers[0]["last_stream_ordering"] # Advance time a bit, so the pusher will register something has happened - self.pump(100) + self.pump(10) # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( diff --git a/tests/push/test_http.py b/tests/push/test_http.py index baf9c785f4..b567868b02 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -25,7 +25,6 @@ from tests.unittest import HomeserverTestCase class HTTPPusherTests(HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -35,7 +34,6 @@ class HTTPPusherTests(HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor, clock): - self.push_attempts = [] m = Mock() @@ -90,9 +88,6 @@ class HTTPPusherTests(HomeserverTestCase): # Create a room room = self.helper.create_room_as(user_id, tok=access_token) - # Invite the other person - self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) - # The other user joins self.helper.join(room=room, user=other_user_id, tok=other_access_token) @@ -157,3 +152,350 @@ class HTTPPusherTests(HomeserverTestCase): pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + + def test_sends_high_priority_for_encrypted(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to an encrypted message. + This will happen both in 1:1 rooms and larger rooms. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send an encrypted event + # I know there'd normally be set-up of an encrypted room first + # but this will do for our purposes + self.helper.send_event( + room, + "m.room.encrypted", + content={ + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "6lImKbzK51MzWLwHh8tUM3UBBSBrLlgup/OOCGTvumM", + "ciphertext": "AwgAErABoRxwpMipdgiwXgu46rHiWQ0DmRj0qUlPrMraBUDk" + "leTnJRljpuc7IOhsYbLY3uo2WI0ab/ob41sV+3JEIhODJPqH" + "TK7cEZaIL+/up9e+dT9VGF5kRTWinzjkeqO8FU5kfdRjm+3w" + "0sy3o1OCpXXCfO+faPhbV/0HuK4ndx1G+myNfK1Nk/CxfMcT" + "BT+zDS/Df/QePAHVbrr9uuGB7fW8ogW/ulnydgZPRluusFGv" + "J3+cg9LoPpZPAmv5Me3ec7NtdlfN0oDZ0gk3TiNkkhsxDG9Y" + "YcNzl78USI0q8+kOV26Bu5dOBpU4WOuojXZHJlP5lMgdzLLl" + "EQ0", + "session_id": "IigqfNWLL+ez/Is+Duwp2s4HuCZhFG9b9CZKTYHtQ4A", + "device_id": "AHQDUSTAAA", + }, + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Add yet another person — we want to make this room not a 1:1 + # (as encrypted messages in a 1:1 currently have tweaks applied + # so it doesn't properly exercise the condition of all encrypted + # messages need to be high). + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Check no push notifications are sent regarding the membership changes + # (that would confuse the test) + self.pump() + self.assertEqual(len(self.push_attempts), 1) + + # Send another encrypted event + self.helper.send_event( + room, + "m.room.encrypted", + content={ + "ciphertext": "AwgAEoABtEuic/2DF6oIpNH+q/PonzlhXOVho8dTv0tzFr5m" + "9vTo50yabx3nxsRlP2WxSqa8I07YftP+EKWCWJvTkg6o7zXq" + "6CK+GVvLQOVgK50SfvjHqJXN+z1VEqj+5mkZVN/cAgJzoxcH" + "zFHkwDPJC8kQs47IHd8EO9KBUK4v6+NQ1uE/BIak4qAf9aS/" + "kI+f0gjn9IY9K6LXlah82A/iRyrIrxkCkE/n0VfvLhaWFecC" + "sAWTcMLoF6fh1Jpke95mljbmFSpsSd/eEQw", + "device_id": "SRCFTWTHXO", + "session_id": "eMA+bhGczuTz1C5cJR1YbmrnnC6Goni4lbvS5vJ1nG4", + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "rC/XSIAiYrVGSuaHMop8/pTZbku4sQKBZwRwukgnN1c", + }, + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") + + def test_sends_high_priority_for_one_to_one_only(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message in a one-to-one room. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send(room, body="Hi!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority — this is a one-to-one room + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Yet another user joins + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Check no push notifications are sent regarding the membership changes + # (that would confuse the test) + self.pump() + self.assertEqual(len(self.push_attempts), 1) + + # Send another event + self.helper.send(room, body="Welcome!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_sends_high_priority_for_mention(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message containing the user's display name. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other users join + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send(room, body="Oh, user, hello!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Send another event, this time with no mention + self.helper.send(room, body="Are you there?", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_sends_high_priority_for_atroom(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message that contains @room. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room (as other_user so the power levels are compatible with + # other_user sending @room). + room = self.helper.create_room_as(other_user_id, tok=other_access_token) + + # The other users join + self.helper.join(room=room, user=user_id, tok=access_token) + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send( + room, + body="@room eeek! There's a spider on the table!", + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Send another event, this time as someone without the power of @room + self.helper.send( + room, body="@room the spider is gone", tok=yet_another_access_token + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 9ae6a87d7b..1f4b5ca2ac 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -15,13 +15,14 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent +from synapse.push import push_rule_evaluator from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent from tests import unittest class PushRuleEvaluatorTestCase(unittest.TestCase): - def setUp(self): + def _get_evaluator(self, content): event = FrozenEvent( { "event_id": "$event_id", @@ -29,37 +30,74 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "sender": "@user:test", "state_key": "", "room_id": "@room:test", - "content": {"body": "foo bar baz"}, + "content": content, }, RoomVersions.V1, ) room_member_count = 0 sender_power_level = 0 power_levels = {} - self.evaluator = PushRuleEvaluatorForEvent( + return PushRuleEvaluatorForEvent( event, room_member_count, sender_power_level, power_levels ) def test_display_name(self): """Check for a matching display name in the body of the event.""" + evaluator = self._get_evaluator({"body": "foo bar baz"}) + condition = { "kind": "contains_display_name", } # Blank names are skipped. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "")) + self.assertFalse(evaluator.matches(condition, "@user:test", "")) # Check a display name that doesn't match. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found")) + self.assertFalse(evaluator.matches(condition, "@user:test", "not found")) # Check a display name which matches. - self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo")) + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) # A display name that matches, but not a full word does not result in a match. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba")) + self.assertFalse(evaluator.matches(condition, "@user:test", "ba")) # A display name should not be interpreted as a regular expression. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]")) + self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]")) # A display name with spaces should work fine. - self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar")) + self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) + + def test_no_body(self): + """Not having a body shouldn't break the evaluator.""" + evaluator = self._get_evaluator({}) + + condition = { + "kind": "contains_display_name", + } + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + def test_invalid_body(self): + """A non-string body should not break the evaluator.""" + condition = { + "kind": "contains_display_name", + } + + for body in (1, True, {"foo": "bar"}): + evaluator = self._get_evaluator({"body": body}) + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + def test_tweaks_for_actions(self): + """ + This tests the behaviour of tweaks_for_actions. + """ + + actions = [ + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + "notify", + ] + + self.assertEqual( + push_rule_evaluator.tweaks_for_actions(actions), + {"sound": "default", "highlight": True}, + ) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 9d4f0bbe44..ae60874ec3 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import attr @@ -26,8 +26,9 @@ from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, GenericWorkerServer, ) +from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest -from synapse.replication.http import streams +from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -35,7 +36,7 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeTransport +from tests.server import FakeTransport, render logger = logging.getLogger(__name__) @@ -64,7 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Since we use sqlite in memory databases we need to make sure the # databases objects are the same. - self.worker_hs.get_datastore().db = hs.get_datastore().db + self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool self.test_handler = self._build_replication_data_handler() self.worker_hs.replication_data_handler = self.test_handler @@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") +class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests running multiple workers. + + Automatically handle HTTP replication requests from workers to master, + unlike `BaseStreamTestCase`. + """ + + servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]] + + def setUp(self): + super().setUp() + + # build a replication server + self.server_factory = ReplicationStreamProtocolFactory(self.hs) + self.streamer = self.hs.get_replication_streamer() + + store = self.hs.get_datastore() + self.database_pool = store.db_pool + + self.reactor.lookups["testserv"] = "1.2.3.4" + + self._worker_hs_to_resource = {} + + # When we see a connection attempt to the master replication listener we + # automatically set up the connection. This is so that tests don't + # manually have to go and explicitly set it up each time (plus sometimes + # it is impossible to write the handling explicitly in the tests). + self.reactor.add_tcp_client_callback( + "1.2.3.4", 8765, self._handle_http_replication_attempt + ) + + def create_test_json_resource(self): + """Overrides `HomeserverTestCase.create_test_json_resource`. + """ + # We override this so that it automatically registers all the HTTP + # replication servlets, without having to explicitly do that in all + # subclassses. + + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(self.hs, resource) + + return resource + + def make_worker_hs( + self, worker_app: str, extra_config: dict = {}, **kwargs + ) -> HomeServer: + """Make a new worker HS instance, correctly connecting replcation + stream to the master HS. + + Args: + worker_app: Type of worker, e.g. `synapse.app.federation_sender`. + extra_config: Any extra config to use for this instances. + **kwargs: Options that get passed to `self.setup_test_homeserver`, + useful to e.g. pass some mocks for things like `http_client` + + Returns: + The new worker HomeServer instance. + """ + + config = self._get_worker_hs_config() + config["worker_app"] = worker_app + config.update(extra_config) + + worker_hs = self.setup_test_homeserver( + homeserverToUse=GenericWorkerServer, + config=config, + reactor=self.reactor, + **kwargs + ) + + store = worker_hs.get_datastore() + store.db_pool._db_pool = self.database_pool._db_pool + + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) + + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) + + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) + + # Set up a resource for the worker + resource = ReplicationRestResource(self.hs) + + for servlet in self.servlets: + servlet(worker_hs, resource) + + self._worker_hs_to_resource[worker_hs] = resource + + return worker_hs + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): + render(request, self._worker_hs_to_resource[worker_hs], self.reactor) + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump() + + def _handle_http_replication_attempt(self): + """Handles a connection attempt to the master replication HTTP + listener. + """ + + # We should have at least one outbound connection attempt, where the + # last is one to the HTTP repication IP/port. + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8765) + + # Set up client side protocol + client_protocol = client_factory.buildProtocol(None) + + request_factory = OneShotRequestFactory() + + # Set up the server side protocol + channel = _PushHTTPChannel(self.reactor) + channel.requestFactory = request_factory + channel.site = self.site + + # Connect client to server and vice versa. + client_to_server_transport = FakeTransport( + channel, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, channel + ) + channel.makeConnection(server_to_client_transport) + + # Note: at this point we've wired everything up, but we need to return + # before the data starts flowing over the connections as this is called + # inside `connecTCP` before the connection has been passed back to the + # code that requested the TCP connection. + + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel): # We need to manually stop the _PullToPushProducer. self._pull_to_push_producer.stop() + def checkPersistence(self, request, version): + """Check whether the connection can be re-used + """ + # We hijack this to always say no for ease of wiring stuff up in + # `handle_http_replication_attempt`. + request.responseHeaders.setRawHeaders(b"connection", [b"close"]) + return False + class _PullToPushProducer: """A push producer that wraps a pull producer. diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 1a88c7fb80..561258a356 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 0}, + {"highlight_count": 0, "unread_count": 0, "notify_count": 0}, ) self.persist( @@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 1}, + {"highlight_count": 0, "unread_count": 0, "notify_count": 1}, ) self.persist( @@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2}, + {"highlight_count": 1, "unread_count": 0, "notify_count": 2}, ) def test_get_rooms_for_user_with_stream_ordering(self): @@ -366,7 +366,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_handler = self.hs.get_state_handler() context = self.get_success(state_handler.compute_event_context(event)) - self.master_store.add_push_actions_to_staging( - event.event_id, {user_id: actions for user_id, actions in push_actions} + self.get_success( + self.master_store.add_push_actions_to_staging( + event.event_id, + {user_id: actions for user_id, actions in push_actions}, + False, + ) ) return event, context diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 51bf0ef4e9..c9998e88e6 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -17,6 +17,7 @@ from typing import List, Optional from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase +from synapse.replication.tcp.commands import RdataCommand from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT from synapse.replication.tcp.streams.events import ( EventsStreamCurrentStateRow, @@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase): # also one state event state_event = self._inject_state_event() - # tell the notifier to catch up to avoid duplicate rows. - # workaround for https://github.com/matrix-org/synapse/issues/7360 - # FIXME remove this when the above is fixed - self.replicate() - # check we're testing what we think we are: no rows should yet have been # received self.assertEqual([], self.test_handler.received_rdata_rows) @@ -123,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase): OTHER_USER = "@other_user:localhost" # have the user join - inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -161,24 +159,21 @@ class EventsStreamTestCase(BaseStreamTestCase): # roll back all the state by de-modding the user prev_events = fork_point pls["users"][OTHER_USER] = 0 - pl_event = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + pl_event = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) # one more bit of state that doesn't get rolled back state2 = self._inject_state_event() - # tell the notifier to catch up to avoid duplicate rows. - # workaround for https://github.com/matrix-org/synapse/issues/7360 - # FIXME remove this when the above is fixed - self.replicate() - # check we're testing what we think we are: no rows should yet have been # received self.assertEqual([], self.test_handler.received_rdata_rows) @@ -277,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # have the users join for u in user_ids: - inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -315,23 +312,20 @@ class EventsStreamTestCase(BaseStreamTestCase): pl_events = [] for u in user_ids: pls["users"][u] = 0 - e = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + e = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) prev_events = [e.event_id] pl_events.append(e) - # tell the notifier to catch up to avoid duplicate rows. - # workaround for https://github.com/matrix-org/synapse/issues/7360 - # FIXME remove this when the above is fixed - self.replicate() - # check we're testing what we think we are: no rows should yet have been # received self.assertEqual([], self.test_handler.received_rdata_rows) @@ -378,6 +372,64 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual([], received_rows) + def test_backwards_stream_id(self): + """ + Test that RDATA that comes after the current position should be discarded. + """ + # disconnect, so that we can stack up some changes + self.disconnect() + + # Generate an events. We inject them using inject_event so that they are + # not send out over replication until we call self.replicate(). + event = self._inject_test_event() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # We should have received the expected single row (as well as various + # cache invalidation updates which we ignore). + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + + # There should be a single received row. + self.assertEqual(len(received_rows), 1) + + stream_name, token, row = received_rows[0] + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, event.event_id) + + # Reset the data. + self.test_handler.received_rdata_rows = [] + + # Save the current token for later. + worker_events_stream = self.worker_hs.get_replication_streams()["events"] + prev_token = worker_events_stream.current_token("master") + + # Manually send an old RDATA command, which should get dropped. This + # re-uses the row from above, but with an earlier stream token. + self.hs.get_tcp_replication().send_command( + RdataCommand("events", "master", 1, row) + ) + + # No updates have been received (because it was discard as old). + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + self.assertEqual(len(received_rows), 0) + + # Ensure the stream has not gone backwards. + current_token = worker_events_stream.current_token("master") + self.assertGreaterEqual(current_token, prev_token) + event_count = 0 def _inject_test_event( @@ -390,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase): body = "event %i" % (self.event_count,) self.event_count += 1 - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_event", - content={"body": body}, - **kwargs + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) ) def _inject_state_event( @@ -415,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase): if body is None: body = "state event %s" % (state_key,) - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_state_event", - state_key=state_key, - content={"body": body}, + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) ) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index fd62b26356..5acfb3e53e 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -16,10 +16,15 @@ from mock import Mock from synapse.handlers.typing import RoomMember from synapse.replication.tcp.streams import TypingStream +from synapse.util.caches.stream_change_cache import StreamChangeCache from tests.replication._base import BaseStreamTestCase USER_ID = "@feeling:blue" +USER_ID_2 = "@da-ba-dee:blue" + +ROOM_ID = "!bar:blue" +ROOM_ID_2 = "!foo:blue" class TypingStreamTestCase(BaseStreamTestCase): @@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase): def test_typing(self): typing = self.hs.get_typing_handler() - room_id = "!bar:blue" - self.reconnect() - typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) + typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) self.reactor.advance(0) @@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow - self.assertEqual(room_id, row.room_id) + self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) # Now let's disconnect and insert some data. @@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.test_handler.on_rdata.reset_mock() - typing._push_update(member=RoomMember(room_id, USER_ID), typing=False) + typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False) self.test_handler.on_rdata.assert_not_called() @@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] - self.assertEqual(room_id, row.room_id) + self.assertEqual(ROOM_ID, row.room_id) + self.assertEqual([], row.user_ids) + + def test_reset(self): + """ + Test what happens when a typing stream resets. + + This is emulated by jumping the stream ahead, then reconnecting (which + sends the proper position and RDATA). + """ + typing = self.hs.get_typing_handler() + + self.reconnect() + + typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) + + self.reactor.advance(0) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] # type: TypingStream.TypingStreamRow + self.assertEqual(ROOM_ID, row.room_id) + self.assertEqual([USER_ID], row.user_ids) + + # Push the stream forward a bunch so it can be reset. + for i in range(100): + typing._push_update( + member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True + ) + self.reactor.advance(0) + + # Disconnect. + self.disconnect() + + # Reset the typing handler + self.hs.get_replication_streams()["typing"].last_token = 0 + self.hs.get_tcp_replication()._streams["typing"].last_token = 0 + typing._latest_room_serial = 0 + typing._typing_stream_change_cache = StreamChangeCache( + "TypingStreamChangeCache", typing._latest_room_serial + ) + typing._reset() + + # Reconnect. + self.reconnect() + self.pump(0.1) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + # Reset the test code. + self.test_handler.on_rdata.reset_mock() + self.test_handler.on_rdata.assert_not_called() + + # Push additional data. + typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False) + self.reactor.advance(0) + + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] + self.assertEqual(ROOM_ID_2, row.room_id) self.assertEqual([], row.user_ids) + + # The token should have been reset. + self.assertEqual(token, 1) diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py new file mode 100644 index 0000000000..86c03fd89c --- /dev/null +++ b/tests/replication/test_client_reader_shard.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.api.constants import LoginType +from synapse.http.site import SynapseRequest +from synapse.rest.client.v2_alpha import register + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker +from tests.server import FakeChannel + +logger = logging.getLogger(__name__) + + +class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): + """Base class for tests of the replication streams""" + + servlets = [register.register_servlets] + + def prepare(self, reactor, clock, hs): + self.recaptcha_checker = DummyRecaptchaChecker(hs) + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.client_reader" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def test_register_single_worker(self): + """Test that registration works when using a single client reader worker. + """ + worker_hs = self.make_worker_hs("synapse.app.client_reader") + + request_1, channel_1 = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) # type: SynapseRequest, FakeChannel + self.render_on_worker(worker_hs, request_1) + self.assertEqual(request_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + request_2, channel_2 = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) # type: SynapseRequest, FakeChannel + self.render_on_worker(worker_hs, request_2) + self.assertEqual(request_2.code, 200) + + # We're given a registered user. + self.assertEqual(channel_2.json_body["user_id"], "@user:test") + + def test_register_multi_worker(self): + """Test that registration works when using multiple client reader workers. + """ + worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") + worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") + + request_1, channel_1 = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) # type: SynapseRequest, FakeChannel + self.render_on_worker(worker_hs_1, request_1) + self.assertEqual(request_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + request_2, channel_2 = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) # type: SynapseRequest, FakeChannel + self.render_on_worker(worker_hs_2, request_2) + self.assertEqual(request_2.code, 200) + + # We're given a registered user. + self.assertEqual(channel_2.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 5448d9f0dc..23be1167a3 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer) + return hs def test_federation_ack_sent(self): diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py new file mode 100644 index 0000000000..8b4982ecb1 --- /dev/null +++ b/tests/replication/test_federation_sender_shard.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from mock import Mock + +from synapse.api.constants import EventTypes, Membership +from synapse.events.builder import EventBuilderFactory +from synapse.rest.admin import register_servlets_for_client_rest_resource +from synapse.rest.client.v1 import login, room +from synapse.types import UserID, create_requester + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.test_utils import make_awaitable + +logger = logging.getLogger(__name__) + + +class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): + servlets = [ + login.register_servlets, + register_servlets_for_client_rest_resource, + room.register_servlets, + ] + + def default_config(self): + conf = super().default_config() + conf["send_federation"] = False + return conf + + def test_send_event_single_sender(self): + """Test that using a single federation sender worker correctly sends a + new event. + """ + mock_client = Mock(spec=["put_json"]) + mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) + + self.make_worker_hs( + "synapse.app.federation_sender", + {"send_federation": True}, + http_client=mock_client, + ) + + user = self.register_user("user", "pass") + token = self.login("user", "pass") + + room = self.create_room_with_remote_server(user, token) + + mock_client.put_json.reset_mock() + + self.create_and_send_event(room, UserID.from_string(user)) + self.replicate() + + # Assert that the event was sent out over federation. + mock_client.put_json.assert_called() + self.assertEqual(mock_client.put_json.call_args[0][0], "other_server") + self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus")) + + def test_send_event_sharded(self): + """Test that using two federation sender workers correctly sends + new events. + """ + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) + self.make_worker_hs( + "synapse.app.federation_sender", + { + "send_federation": True, + "worker_name": "sender1", + "federation_sender_instances": ["sender1", "sender2"], + }, + http_client=mock_client1, + ) + + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) + self.make_worker_hs( + "synapse.app.federation_sender", + { + "send_federation": True, + "worker_name": "sender2", + "federation_sender_instances": ["sender1", "sender2"], + }, + http_client=mock_client2, + ) + + user = self.register_user("user2", "pass") + token = self.login("user2", "pass") + + sent_on_1 = False + sent_on_2 = False + for i in range(20): + server_name = "other_server_%d" % (i,) + room = self.create_room_with_remote_server(user, token, server_name) + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] + + self.create_and_send_event(room, UserID.from_string(user)) + self.replicate() + + if mock_client1.put_json.called: + sent_on_1 = True + mock_client2.put_json.assert_not_called() + self.assertEqual(mock_client1.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus")) + elif mock_client2.put_json.called: + sent_on_2 = True + mock_client1.put_json.assert_not_called() + self.assertEqual(mock_client2.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus")) + else: + raise AssertionError( + "Expected send transaction from one or the other sender" + ) + + if sent_on_1 and sent_on_2: + break + + self.assertTrue(sent_on_1) + self.assertTrue(sent_on_2) + + def test_send_typing_sharded(self): + """Test that using two federation sender workers correctly sends + new typing EDUs. + """ + mock_client1 = Mock(spec=["put_json"]) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) + self.make_worker_hs( + "synapse.app.federation_sender", + { + "send_federation": True, + "worker_name": "sender1", + "federation_sender_instances": ["sender1", "sender2"], + }, + http_client=mock_client1, + ) + + mock_client2 = Mock(spec=["put_json"]) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) + self.make_worker_hs( + "synapse.app.federation_sender", + { + "send_federation": True, + "worker_name": "sender2", + "federation_sender_instances": ["sender1", "sender2"], + }, + http_client=mock_client2, + ) + + user = self.register_user("user3", "pass") + token = self.login("user3", "pass") + + typing_handler = self.hs.get_typing_handler() + + sent_on_1 = False + sent_on_2 = False + for i in range(20): + server_name = "other_server_%d" % (i,) + room = self.create_room_with_remote_server(user, token, server_name) + mock_client1.reset_mock() # type: ignore[attr-defined] + mock_client2.reset_mock() # type: ignore[attr-defined] + + self.get_success( + typing_handler.started_typing( + target_user=UserID.from_string(user), + requester=create_requester(user), + room_id=room, + timeout=20000, + ) + ) + + self.replicate() + + if mock_client1.put_json.called: + sent_on_1 = True + mock_client2.put_json.assert_not_called() + self.assertEqual(mock_client1.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus")) + elif mock_client2.put_json.called: + sent_on_2 = True + mock_client1.put_json.assert_not_called() + self.assertEqual(mock_client2.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus")) + else: + raise AssertionError( + "Expected send transaction from one or the other sender" + ) + + if sent_on_1 and sent_on_2: + break + + self.assertTrue(sent_on_1) + self.assertTrue(sent_on_2) + + def create_room_with_remote_server(self, user, token, remote_server="other_server"): + room = self.helper.create_room_as(user, tok=token) + store = self.hs.get_datastore() + federation = self.hs.get_handlers().federation_handler + + prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) + room_version = self.get_success(store.get_room_version(room)) + + factory = EventBuilderFactory(self.hs) + factory.hostname = remote_server + + user_id = UserID("user", remote_server).to_string() + + event_dict = { + "type": EventTypes.Member, + "state_key": user_id, + "content": {"membership": Membership.JOIN}, + "sender": user_id, + "room_id": room, + } + + builder = factory.for_room_version(room_version, event_dict) + join_event = self.get_success(builder.build(prev_event_ids)) + + self.get_success(federation.on_send_join_request(remote_server, join_event)) + self.replicate() + + return room diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py new file mode 100644 index 0000000000..2bdc6edbb1 --- /dev/null +++ b/tests/replication/test_pusher_shard.py @@ -0,0 +1,193 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from mock import Mock + +from twisted.internet import defer + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + + +class PusherShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks pusher sharding works + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Register a user who sends a message that we'll get notified about + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + def default_config(self): + conf = super().default_config() + conf["start_pushers"] = False + return conf + + def _create_pusher_and_send_msg(self, localpart): + # Create a user that will get push notifications + user_id = self.register_user(localpart, "pass") + access_token = self.login(localpart, "pass") + + # Register a pusher + user_dict = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_dict["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "https://push.example.com/push"}, + ) + ) + + self.pump() + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room, user=self.other_user_id, tok=self.other_access_token + ) + + # The other user sends some messages + response = self.helper.send(room, body="Hi!", tok=self.other_access_token) + event_id = response["event_id"] + + return event_id + + def test_send_push_single_worker(self): + """Test that registration works when using a pusher worker. + """ + http_client_mock = Mock(spec_set=["post_json_get_json"]) + http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + {"start_pushers": True}, + proxied_http_client=http_client_mock, + ) + + event_id = self._create_pusher_and_send_msg("user") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock.post_json_get_json.assert_called_once() + self.assertEqual( + http_client_mock.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) + + def test_send_push_multiple_workers(self): + """Test that registration works when using sharded pusher workers. + """ + http_client_mock1 = Mock(spec_set=["post_json_get_json"]) + http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher1", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock1, + ) + + http_client_mock2 = Mock(spec_set=["post_json_get_json"]) + http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher2", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock2, + ) + + # We choose a user name that we know should go to pusher1. + event_id = self._create_pusher_and_send_msg("user2") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_called_once() + http_client_mock2.post_json_get_json.assert_not_called() + self.assertEqual( + http_client_mock1.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock1.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) + + http_client_mock1.post_json_get_json.reset_mock() + http_client_mock2.post_json_get_json.reset_mock() + + # Now we choose a user name that we know should go to pusher2. + event_id = self._create_pusher_and_send_msg("user4") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_not_called() + http_client_mock2.post_json_get_json.assert_called_once() + self.assertEqual( + http_client_mock2.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock2.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 977615ebef..0f1144fe1e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.fetches = [] - def get_file(destination, path, output_stream, args=None, max_size=None): + async def get_file(destination, path, output_stream, args=None, max_size=None): """ Returns tuple[int,dict,str,int] of file length, response headers, absolute URI, and response code. @@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): d = Deferred() d.addCallback(write_to) self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + return await make_deferred_yieldable(d) client = Mock() client.get_file = get_file @@ -220,6 +220,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): return hs + def _ensure_quarantined(self, admin_user_tok, server_and_media_id): + """Ensure a piece of media is quarantined when trying to access it.""" + request, channel = self.make_request( + "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id + ), + ) + def test_quarantine_media_requires_admin(self): self.register_user("nonadmin", "pass", admin=False) non_admin_user_tok = self.login("nonadmin", "pass") @@ -292,24 +310,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Attempt to access the media - request, channel = self.make_request( - "GET", - server_name_and_media_id, - shorthand=False, - access_token=admin_user_tok, - ) - request.render(self.download_resource) - self.pump(1.0) - - # Should be quarantined - self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_name_and_media_id - ), - ) + self._ensure_quarantined(admin_user_tok, server_name_and_media_id) def test_quarantine_all_media_in_room(self, override_url_template=None): self.register_user("room_admin", "pass", admin=True) @@ -371,45 +372,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): server_and_media_id_2 = mxc_2[6:] # Test that we cannot download any of the media anymore - request, channel = self.make_request( - "GET", - server_and_media_id_1, - shorthand=False, - access_token=non_admin_user_tok, - ) - request.render(self.download_resource) - self.pump(1.0) - - # Should be quarantined - self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_and_media_id_1 - ), - ) - - request, channel = self.make_request( - "GET", - server_and_media_id_2, - shorthand=False, - access_token=non_admin_user_tok, - ) - request.render(self.download_resource) - self.pump(1.0) - - # Should be quarantined - self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_and_media_id_2 - ), - ) + self._ensure_quarantined(admin_user_tok, server_and_media_id_1) + self._ensure_quarantined(admin_user_tok, server_and_media_id_2) - def test_quaraantine_all_media_in_room_deprecated_api_path(self): + def test_quarantine_all_media_in_room_deprecated_api_path(self): # Perform the above test with the deprecated API path self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s") @@ -449,25 +415,52 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Attempt to access each piece of media + self._ensure_quarantined(admin_user_tok, server_and_media_id_1) + self._ensure_quarantined(admin_user_tok, server_and_media_id_2) + + def test_cannot_quarantine_safe_media(self): + self.register_user("user_admin", "pass", admin=True) + admin_user_tok = self.login("user_admin", "pass") + + non_admin_user = self.register_user("user_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("user_nonadmin", "pass") + + # Upload some media + response_1 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + response_2 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + + # Extract media IDs + server_and_media_id_1 = response_1["content_uri"][6:] + server_and_media_id_2 = response_2["content_uri"][6:] + + # Mark the second item as safe from quarantine. + _, media_id_2 = server_and_media_id_2.split("/") + self.get_success(self.store.mark_local_media_as_safe(media_id_2)) + + # Quarantine all media by this user + url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( + non_admin_user + ) request, channel = self.make_request( - "GET", - server_and_media_id_1, - shorthand=False, - access_token=non_admin_user_tok, + "POST", url.encode("ascii"), access_token=admin_user_tok, ) - request.render(self.download_resource) + self.render(request) self.pump(1.0) - - # Should be quarantined + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - 404, - int(channel.code), - msg=( - "Expected to receive a 404 on accessing quarantined media: %s" - % server_and_media_id_1, - ), + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 1}, + "Expected 1 quarantined item", ) + # Attempt to access each piece of media, the first should fail, the + # second should succeed. + self._ensure_quarantined(admin_user_tok, server_and_media_id_1) + # Attempt to access each piece of media request, channel = self.make_request( "GET", @@ -478,12 +471,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): request.render(self.download_resource) self.pump(1.0) - # Should be quarantined + # Shouldn't be quarantined self.assertEqual( - 404, + 200, int(channel.code), msg=( - "Expected to receive a 404 on accessing quarantined media: %s" + "Expected to receive a 200 on accessing not-quarantined media: %s" % server_and_media_id_2 ), ) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 54cd24bf64..408c568a27 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1,1007 +1,1500 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Dirk Klimpel -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import urllib.parse -from typing import List, Optional - -from mock import Mock - -import synapse.rest.admin -from synapse.api.errors import Codes -from synapse.rest.client.v1 import directory, events, login, room - -from tests import unittest - -"""Tests admin REST events for /rooms paths.""" - - -class ShutdownRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - """ - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - room_id, body="foo", tok=self.other_user_token, expect_code=403 - ) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert there is now no longer anyone in the room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. - """ - - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_token, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert we can no longer peek into the room - self._assert_peek(room_id, expect_code=403) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ - - url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - -class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_purge_room(self): - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # All users have to have left the room. - self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) - - url = "/_synapse/admin/v1/purge_room" - request, channel = self.make_request( - "POST", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the following tables have been purged of all rows related to the room. - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - # "state_groups", # Current impl leaves orphaned state groups around. - "state_groups_state", - ): - count = self.get_success( - self.store.db.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) - - -class RoomTestCase(unittest.HomeserverTestCase): - """Test /room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - directory.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - # Create user - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_list_rooms(self): - """Test that we can list rooms""" - # Create 3 test rooms - total_rooms = 3 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - - # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) - - # Check that response json body contains a "rooms" key - self.assertTrue( - "rooms" in channel.json_body, - msg="Response body does not " "contain a 'rooms' key", - ) - - # Check that 3 rooms were returned - self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) - - # Check their room_ids match - returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] - self.assertEqual(room_ids, returned_room_ids) - - # Check that all fields are available - for r in channel.json_body["rooms"]: - self.assertIn("name", r) - self.assertIn("canonical_alias", r) - self.assertIn("joined_members", r) - self.assertIn("joined_local_members", r) - self.assertIn("version", r) - self.assertIn("creator", r) - self.assertIn("encryption", r) - self.assertIn("federatable", r) - self.assertIn("public", r) - self.assertIn("join_rules", r) - self.assertIn("guest_access", r) - self.assertIn("history_visibility", r) - self.assertIn("state_events", r) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # Should be 0 as we aren't paginating - self.assertEqual(channel.json_body["offset"], 0) - - # Check that the prev_batch parameter is not present - self.assertNotIn("prev_batch", channel.json_body) - - # We shouldn't receive a next token here as there's no further rooms to show - self.assertNotIn("next_batch", channel.json_body) - - def test_list_rooms_pagination(self): - """Test that we can get a full list of rooms through pagination""" - # Create 5 test rooms - total_rooms = 5 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Set the name of the rooms so we get a consistent returned ordering - for idx, room_id in enumerate(room_ids): - self.helper.send_state( - room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - returned_room_ids = [] - start = 0 - limit = 2 - - run_count = 0 - should_repeat = True - while should_repeat: - run_count += 1 - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( - start, - limit, - "name", - ) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) - - self.assertTrue("rooms" in channel.json_body) - for r in channel.json_body["rooms"]: - returned_room_ids.append(r["room_id"]) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # We're only getting 2 rooms each page, so should be 2 * last run_count - self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) - - if run_count > 1: - # Check the value of prev_batch is correct - self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) - - if "next_batch" not in channel.json_body: - # We have reached the end of the list - should_repeat = False - else: - # Make another query with an updated start value - start = channel.json_body["next_batch"] - - # We should've queried the endpoint 3 times - self.assertEqual( - run_count, - 3, - msg="Should've queried 3 times for 5 rooms with limit 2 per query", - ) - - # Check that we received all of the room ids - self.assertEqual(room_ids, returned_room_ids) - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - def test_correct_room_attributes(self): - """Test the correct attributes for a room are returned""" - # Create a test room - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - test_alias = "#test:test" - test_room_name = "something" - - # Have another user join the room - user_2 = self.register_user("user4", "pass") - user_tok_2 = self.login("user4", "pass") - self.helper.join(room_id, user_2, tok=user_tok_2) - - # Create a new alias to this room - url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Set this new alias as the canonical alias for this room - self.helper.send_state( - room_id, - "m.room.aliases", - {"aliases": [test_alias]}, - tok=self.admin_user_tok, - state_key="test", - ) - self.helper.send_state( - room_id, - "m.room.canonical_alias", - {"alias": test_alias}, - tok=self.admin_user_tok, - ) - - # Set a name for the room - self.helper.send_state( - room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that only one room was returned - self.assertEqual(len(rooms), 1) - - # And that the value of the total_rooms key was correct - self.assertEqual(channel.json_body["total_rooms"], 1) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that all provided attributes are set - r = rooms[0] - self.assertEqual(room_id, r["room_id"]) - self.assertEqual(test_room_name, r["name"]) - self.assertEqual(test_alias, r["canonical_alias"]) - - def test_room_list_sort_order(self): - """Test room list sort ordering. alphabetical name versus number of members, - reversing the order, etc. - """ - - def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): - # Create a new alias to this room - url = "/_matrix/client/r0/directory/room/%s" % ( - urllib.parse.quote(test_alias), - ) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - {"room_id": room_id}, - access_token=admin_user_tok, - ) - self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) - - # Set this new alias as the canonical alias for this room - self.helper.send_state( - room_id, - "m.room.aliases", - {"aliases": [test_alias]}, - tok=admin_user_tok, - state_key="test", - ) - self.helper.send_state( - room_id, - "m.room.canonical_alias", - {"alias": test_alias}, - tok=admin_user_tok, - ) - - def _order_test( - order_type: str, expected_room_list: List[str], reverse: bool = False, - ): - """Request the list of rooms in a certain order. Assert that order is what - we expect - - Args: - order_type: The type of ordering to give the server - expected_room_list: The list of room_ids in the order we expect to get - back from the server - """ - # Request the list of rooms in the given order - url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) - if reverse: - url += "&dir=b" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check for the correct total_rooms value - self.assertEqual(channel.json_body["total_rooms"], 3) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that rooms were returned in alphabetical order - returned_order = [r["room_id"] for r in rooms] - self.assertListEqual(expected_room_list, returned_order) # order is checked - - # Create 3 test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C - self.helper.send_state( - room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, - ) - - # Set room canonical room aliases - _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok) - _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok) - _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok) - - # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 - user_1 = self.register_user("bob1", "pass") - user_1_tok = self.login("bob1", "pass") - self.helper.join(room_id_2, user_1, tok=user_1_tok) - - user_2 = self.register_user("bob2", "pass") - user_2_tok = self.login("bob2", "pass") - self.helper.join(room_id_3, user_2, tok=user_2_tok) - - user_3 = self.register_user("bob3", "pass") - user_3_tok = self.login("bob3", "pass") - self.helper.join(room_id_3, user_3, tok=user_3_tok) - - # Test different sort orders, with forward and reverse directions - _order_test("name", [room_id_1, room_id_2, room_id_3]) - _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3]) - _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("joined_members", [room_id_3, room_id_2, room_id_1]) - _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1]) - _order_test( - "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True - ) - - _order_test("version", [room_id_1, room_id_2, room_id_3]) - _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("creator", [room_id_1, room_id_2, room_id_3]) - _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("encryption", [room_id_1, room_id_2, room_id_3]) - _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("federatable", [room_id_1, room_id_2, room_id_3]) - _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("public", [room_id_1, room_id_2, room_id_3]) - # Different sort order of SQlite and PostreSQL - # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("join_rules", [room_id_1, room_id_2, room_id_3]) - _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("guest_access", [room_id_1, room_id_2, room_id_3]) - _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True) - - _order_test("history_visibility", [room_id_1, room_id_2, room_id_3]) - _order_test( - "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True - ) - - _order_test("state_events", [room_id_3, room_id_2, room_id_1]) - _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) - - def test_search_term(self): - """Test that searching for a room works correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - room_name_1 = "something" - room_name_2 = "else" - - # Set the name for each room - self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, - ) - - def _search_test( - expected_room_id: Optional[str], - search_term: str, - expected_http_code: int = 200, - ): - """Search for a room and check that the returned room's id is a match - - Args: - expected_room_id: The room_id expected to be returned by the API. Set - to None to expect zero results for the search - search_term: The term to search for room names with - expected_http_code: The expected http code for the request - """ - url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - - if expected_http_code != 200: - return - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that the expected number of rooms were returned - expected_room_count = 1 if expected_room_id else 0 - self.assertEqual(len(rooms), expected_room_count) - self.assertEqual(channel.json_body["total_rooms"], expected_room_count) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - if expected_room_id: - # Check that the first returned room id is correct - r = rooms[0] - self.assertEqual(expected_room_id, r["room_id"]) - - # Perform search tests - _search_test(room_id_1, "something") - _search_test(room_id_1, "thing") - - _search_test(room_id_2, "else") - _search_test(room_id_2, "se") - - _search_test(None, "foo") - _search_test(None, "bar") - _search_test(None, "", expected_http_code=400) - - def test_single_room(self): - """Test that a single room can be requested correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - room_name_1 = "something" - room_name_2 = "else" - - # Set the name for each room - self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, - ) - - url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - self.assertIn("room_id", channel.json_body) - self.assertIn("name", channel.json_body) - self.assertIn("canonical_alias", channel.json_body) - self.assertIn("joined_members", channel.json_body) - self.assertIn("joined_local_members", channel.json_body) - self.assertIn("version", channel.json_body) - self.assertIn("creator", channel.json_body) - self.assertIn("encryption", channel.json_body) - self.assertIn("federatable", channel.json_body) - self.assertIn("public", channel.json_body) - self.assertIn("join_rules", channel.json_body) - self.assertIn("guest_access", channel.json_body) - self.assertIn("history_visibility", channel.json_body) - self.assertIn("state_events", channel.json_body) - - self.assertEqual(room_id_1, channel.json_body["room_id"]) - - -class JoinAliasRoomTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.creator = self.register_user("creator", "test") - self.creator_tok = self.login("creator", "test") - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - - self.public_room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=True - ) - self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) - - def test_requester_is_no_admin(self): - """ - If the user is not a server admin, an error 403 is returned. - """ - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.second_tok, - ) - self.render(request) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_invalid_parameter(self): - """ - If a parameter is missing, return an error - """ - body = json.dumps({"unknown_parameter": "@unknown:test"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - - def test_local_user_does_not_exist(self): - """ - Tests that a lookup for a user that does not exist returns a 404 - """ - body = json.dumps({"user_id": "@unknown:test"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_remote_user(self): - """ - Check that only local user can join rooms. - """ - body = json.dumps({"user_id": "@not:exist.bla"}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "This endpoint can only be used with local users", - channel.json_body["error"], - ) - - def test_room_does_not_exist(self): - """ - Check that unknown rooms/server return error 404. - """ - body = json.dumps({"user_id": self.second_user_id}) - url = "/_synapse/admin/v1/join/!unknown:test" - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("No known servers", channel.json_body["error"]) - - def test_room_is_not_valid(self): - """ - Check that invalid room names, return an error 400. - """ - body = json.dumps({"user_id": self.second_user_id}) - url = "/_synapse/admin/v1/join/invalidroom" - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "invalidroom was not legal room ID or room alias", - channel.json_body["error"], - ) - - def test_join_public_room(self): - """ - Test joining a local user to a public room with "JoinRules.PUBLIC" - """ - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - self.url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.public_room_id, channel.json_body["room_id"]) - - # Validate if user is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) - - def test_join_private_room_if_not_member(self): - """ - Test joining a local user to a private room with "JoinRules.INVITE" - when server admin is not member of this room. - """ - private_room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=False - ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - def test_join_private_room_if_member(self): - """ - Test joining a local user to a private room with "JoinRules.INVITE", - when server admin is member of this room. - """ - private_room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=False - ) - self.helper.invite( - room=private_room_id, - src=self.creator, - targ=self.admin_user, - tok=self.creator_tok, - ) - self.helper.join( - room=private_room_id, user=self.admin_user, tok=self.admin_user_tok - ) - - # Validate if server admin is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - - # Join user to room. - - url = "/_synapse/admin/v1/join/{}".format(private_room_id) - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["room_id"]) - - # Validate if user is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - - def test_join_private_room_if_owner(self): - """ - Test joining a local user to a private room with "JoinRules.INVITE", - when server admin is owner of this room. - """ - private_room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok, is_public=False - ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) - body = json.dumps({"user_id": self.second_user_id}) - - request, channel = self.make_request( - "POST", - url, - content=body.encode(encoding="utf_8"), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["room_id"]) - - # Validate if user is a member of the room - - request, channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, - ) - self.render(request) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import urllib.parse +from typing import List, Optional + +from mock import Mock + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import directory, events, login, room + +from tests import unittest + +"""Tests admin REST events for /rooms paths.""" + + +class ShutdownRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + """ + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + room_id, body="foo", tok=self.other_user_token, expect_code=403 + ) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert there is now no longer anyone in the room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + """ + + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_token, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert we can no longer peek into the room + self._assert_peek(room_id, expect_code=403) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + +class DeleteRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + request, channel = self.make_request( + "POST", self.url, json.dumps({}), access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + url = "/_synapse/admin/v1/rooms/!unknown:test/delete" + + request, channel = self.make_request( + "POST", url, json.dumps({}), access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. + """ + url = "/_synapse/admin/v1/rooms/invalidroom/delete" + + request, channel = self.make_request( + "POST", url, json.dumps({}), access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom is not a legal room ID", channel.json_body["error"], + ) + + def test_new_room_user_does_not_exist(self): + """ + Tests that the user ID must be from local server but it does not have to exist. + """ + body = json.dumps({"new_room_user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("kicked_users", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + def test_new_room_user_is_not_local(self): + """ + Check that only local users can create new room to move members. + """ + body = json.dumps({"new_room_user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "User must be our own: @not:exist.bla", channel.json_body["error"], + ) + + def test_block_is_not_bool(self): + """ + If parameter `block` is not boolean, return an error + """ + body = json.dumps({"block": "NotBool"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + body = json.dumps({"purge": "NotBool"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_room_and_block(self): + """Test to purge a room and block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": True, "purge": True}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_purge_room_and_not_block(self): + """Test to purge a room and do not block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": False, "purge": True}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": False, "purge": False}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + ) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["new_room_id"], user_id=self.other_user + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("new_room_id", channel.json_body) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["new_room_id"], user_id=self.other_user + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + # Assert we can no longer peek into the room + self._assert_peek(self.room_id, expect_code=403) + + def _is_blocked(self, room_id, expect=True): + """Assert that the room is blocked or not + """ + d = self.store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _has_no_members(self, room_id): + """Assert there is now no longer anyone in the room + """ + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def _is_member(self, room_id, user_id): + """Test that user is member of the room + """ + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertIn(user_id, users_in_room) + + def _is_purged(self, room_id): + """Test that the following tables have been purged of all rows related to the room. + """ + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", + ): + count = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + +class PurgeRoomTestCase(unittest.HomeserverTestCase): + """Test /purge_room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_purge_room(self): + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # All users have to have left the room. + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/purge_room" + request, channel = self.make_request( + "POST", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the following tables have been purged of all rows related to the room. + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", + ): + count = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + +class RoomTestCase(unittest.HomeserverTestCase): + """Test /room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_list_rooms(self): + """Test that we can list rooms""" + # Create 3 test rooms + total_rooms = 3 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + + # Check request completed successfully + self.assertEqual(200, int(channel.code), msg=channel.json_body) + + # Check that response json body contains a "rooms" key + self.assertTrue( + "rooms" in channel.json_body, + msg="Response body does not " "contain a 'rooms' key", + ) + + # Check that 3 rooms were returned + self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) + + # Check their room_ids match + returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] + self.assertEqual(room_ids, returned_room_ids) + + # Check that all fields are available + for r in channel.json_body["rooms"]: + self.assertIn("name", r) + self.assertIn("canonical_alias", r) + self.assertIn("joined_members", r) + self.assertIn("joined_local_members", r) + self.assertIn("version", r) + self.assertIn("creator", r) + self.assertIn("encryption", r) + self.assertIn("federatable", r) + self.assertIn("public", r) + self.assertIn("join_rules", r) + self.assertIn("guest_access", r) + self.assertIn("history_visibility", r) + self.assertIn("state_events", r) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # Should be 0 as we aren't paginating + self.assertEqual(channel.json_body["offset"], 0) + + # Check that the prev_batch parameter is not present + self.assertNotIn("prev_batch", channel.json_body) + + # We shouldn't receive a next token here as there's no further rooms to show + self.assertNotIn("next_batch", channel.json_body) + + def test_list_rooms_pagination(self): + """Test that we can get a full list of rooms through pagination""" + # Create 5 test rooms + total_rooms = 5 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Set the name of the rooms so we get a consistent returned ordering + for idx, room_id in enumerate(room_ids): + self.helper.send_state( + room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + returned_room_ids = [] + start = 0 + limit = 2 + + run_count = 0 + should_repeat = True + while should_repeat: + run_count += 1 + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( + start, + limit, + "name", + ) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + self.assertTrue("rooms" in channel.json_body) + for r in channel.json_body["rooms"]: + returned_room_ids.append(r["room_id"]) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # We're only getting 2 rooms each page, so should be 2 * last run_count + self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) + + if run_count > 1: + # Check the value of prev_batch is correct + self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) + + if "next_batch" not in channel.json_body: + # We have reached the end of the list + should_repeat = False + else: + # Make another query with an updated start value + start = channel.json_body["next_batch"] + + # We should've queried the endpoint 3 times + self.assertEqual( + run_count, + 3, + msg="Should've queried 3 times for 5 rooms with limit 2 per query", + ) + + # Check that we received all of the room ids + self.assertEqual(room_ids, returned_room_ids) + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_correct_room_attributes(self): + """Test the correct attributes for a room are returned""" + # Create a test room + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + test_alias = "#test:test" + test_room_name = "something" + + # Have another user join the room + user_2 = self.register_user("user4", "pass") + user_tok_2 = self.login("user4", "pass") + self.helper.join(room_id, user_2, tok=user_tok_2) + + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=self.admin_user_tok, + ) + + # Set a name for the room + self.helper.send_state( + room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that only one room was returned + self.assertEqual(len(rooms), 1) + + # And that the value of the total_rooms key was correct + self.assertEqual(channel.json_body["total_rooms"], 1) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that all provided attributes are set + r = rooms[0] + self.assertEqual(room_id, r["room_id"]) + self.assertEqual(test_room_name, r["name"]) + self.assertEqual(test_alias, r["canonical_alias"]) + + def test_room_list_sort_order(self): + """Test room list sort ordering. alphabetical name versus number of members, + reversing the order, etc. + """ + + def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % ( + urllib.parse.quote(test_alias), + ) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=admin_user_tok, + ) + + def _order_test( + order_type: str, expected_room_list: List[str], reverse: bool = False, + ): + """Request the list of rooms in a certain order. Assert that order is what + we expect + + Args: + order_type: The type of ordering to give the server + expected_room_list: The list of room_ids in the order we expect to get + back from the server + """ + # Request the list of rooms in the given order + url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) + if reverse: + url += "&dir=b" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check for the correct total_rooms value + self.assertEqual(channel.json_body["total_rooms"], 3) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that rooms were returned in alphabetical order + returned_order = [r["room_id"] for r in rooms] + self.assertListEqual(expected_room_list, returned_order) # order is checked + + # Create 3 test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C + self.helper.send_state( + room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + ) + + # Set room canonical room aliases + _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok) + + # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 + user_1 = self.register_user("bob1", "pass") + user_1_tok = self.login("bob1", "pass") + self.helper.join(room_id_2, user_1, tok=user_1_tok) + + user_2 = self.register_user("bob2", "pass") + user_2_tok = self.login("bob2", "pass") + self.helper.join(room_id_3, user_2, tok=user_2_tok) + + user_3 = self.register_user("bob3", "pass") + user_3_tok = self.login("bob3", "pass") + self.helper.join(room_id_3, user_3, tok=user_3_tok) + + # Test different sort orders, with forward and reverse directions + _order_test("name", [room_id_1, room_id_2, room_id_3]) + _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3]) + _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("joined_members", [room_id_3, room_id_2, room_id_1]) + _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1]) + _order_test( + "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("version", [room_id_1, room_id_2, room_id_3]) + _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("creator", [room_id_1, room_id_2, room_id_3]) + _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("encryption", [room_id_1, room_id_2, room_id_3]) + _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("federatable", [room_id_1, room_id_2, room_id_3]) + _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("public", [room_id_1, room_id_2, room_id_3]) + # Different sort order of SQlite and PostreSQL + # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("join_rules", [room_id_1, room_id_2, room_id_3]) + _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("guest_access", [room_id_1, room_id_2, room_id_3]) + _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("history_visibility", [room_id_1, room_id_2, room_id_3]) + _order_test( + "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("state_events", [room_id_3, room_id_2, room_id_1]) + _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) + + def test_search_term(self): + """Test that searching for a room works correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + def _search_test( + expected_room_id: Optional[str], + search_term: str, + expected_http_code: int = 200, + ): + """Search for a room and check that the returned room's id is a match + + Args: + expected_room_id: The room_id expected to be returned by the API. Set + to None to expect zero results for the search + search_term: The term to search for room names with + expected_http_code: The expected http code for the request + """ + url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that the expected number of rooms were returned + expected_room_count = 1 if expected_room_id else 0 + self.assertEqual(len(rooms), expected_room_count) + self.assertEqual(channel.json_body["total_rooms"], expected_room_count) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + if expected_room_id: + # Check that the first returned room id is correct + r = rooms[0] + self.assertEqual(expected_room_id, r["room_id"]) + + # Perform search tests + _search_test(room_id_1, "something") + _search_test(room_id_1, "thing") + + _search_test(room_id_2, "else") + _search_test(room_id_2, "se") + + _search_test(None, "foo") + _search_test(None, "bar") + _search_test(None, "", expected_http_code=400) + + def test_single_room(self): + """Test that a single room can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertIn("room_id", channel.json_body) + self.assertIn("name", channel.json_body) + self.assertIn("canonical_alias", channel.json_body) + self.assertIn("joined_members", channel.json_body) + self.assertIn("joined_local_members", channel.json_body) + self.assertIn("version", channel.json_body) + self.assertIn("creator", channel.json_body) + self.assertIn("encryption", channel.json_body) + self.assertIn("federatable", channel.json_body) + self.assertIn("public", channel.json_body) + self.assertIn("join_rules", channel.json_body) + self.assertIn("guest_access", channel.json_body) + self.assertIn("history_visibility", channel.json_body) + self.assertIn("state_events", channel.json_body) + + self.assertEqual(room_id_1, channel.json_body["room_id"]) + + def test_room_members(self): + """Test that room members can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Have another user join the room + user_1 = self.register_user("foo", "pass") + user_tok_1 = self.login("foo", "pass") + self.helper.join(room_id_1, user_1, tok=user_tok_1) + + # Have another user join the room + user_2 = self.register_user("bar", "pass") + user_tok_2 = self.login("bar", "pass") + self.helper.join(room_id_1, user_2, tok=user_tok_2) + self.helper.join(room_id_2, user_2, tok=user_tok_2) + + # Have another user join the room + user_3 = self.register_user("foobar", "pass") + user_tok_3 = self.login("foobar", "pass") + self.helper.join(room_id_2, user_3, tok=user_tok_3) + + url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + + url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertCountEqual( + ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] + ) + self.assertEqual(channel.json_body["total"], 3) + + +class JoinAliasRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.second_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If a parameter is missing, return an error + """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_local_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + body = json.dumps({"user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_remote_user(self): + """ + Check that only local user can join rooms. + """ + body = json.dumps({"user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "This endpoint can only be used with local users", + channel.json_body["error"], + ) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/!unknown:test" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No known servers", channel.json_body["error"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/invalidroom" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom was not legal room ID or room alias", + channel.json_body["error"], + ) + + def test_join_public_room(self): + """ + Test joining a local user to a public room with "JoinRules.PUBLIC" + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_not_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE" + when server admin is not member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_join_private_room_if_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + self.helper.invite( + room=private_room_id, + src=self.creator, + targ=self.admin_user, + tok=self.creator_tok, + ) + self.helper.join( + room=private_room_id, user=self.admin_user, tok=self.admin_user_tok + ) + + # Validate if server admin is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + # Join user to room. + + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_owner(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is owner of this room. + """ + private_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index cca5f548e6..160c630235 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -27,6 +27,7 @@ from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import sync from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -335,7 +336,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastore() # Set monthly active users to the limit - store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value) + store.get_monthly_active_count = Mock( + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit self.get_failure( @@ -588,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -628,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -857,6 +860,53 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + def test_reactivate_user(self): + """ + Test reactivating another user. + """ + + # Deactivate the user. + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=json.dumps({"deactivated": True}).encode(encoding="utf_8"), + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Attempt to reactivate the user (without a password). + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + # Reactivate the user. + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=json.dumps({"deactivated": False, "password": "foo"}).encode( + encoding="utf_8" + ), + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + def test_set_user_as_admin(self): """ Test setting the admin flag on a user. diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 95475bb651..7d3773ff78 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase): } self.hs = self.setup_test_homeserver(config=config) + return self.hs def prepare(self, reactor, clock, homeserver): self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - def test_retention_state_event(self): - """Tests that the server configuration can limit the values a user can set to the - room's retention policy. + self.store = self.hs.get_datastore() + self.serializer = self.hs.get_event_client_serializer() + self.clock = self.hs.get_clock() + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. """ room_id = self.helper.create_room_as(self.user_id, tok=self.token) + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": one_day_ms * 4}, + body={"max_lifetime": lifetime}, tok=self.token, - expect_code=400, ) + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_with_state_event_outside_allowed(self): + """Tests that the server configuration can override the policy for a room when + running the purge jobs. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set a max_lifetime higher than the maximum allowed value. self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": one_hour_ms}, + body={"max_lifetime": one_day_ms * 4}, tok=self.token, - expect_code=400, ) - def test_retention_event_purged_with_state_event(self): - """Tests that expired events are correctly purged when the room's retention policy - is defined by a state event. - """ - room_id = self.helper.create_room_as(self.user_id, tok=self.token) + # Check that the event is purged after waiting for the maximum allowed duration + # instead of the one specified in the room's policy. + self._test_retention_event_purged(room_id, one_day_ms * 1.5) - # Set the room's retention period to 2 days. - lifetime = one_day_ms * 2 + # Set a max_lifetime lower than the minimum allowed value. self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": lifetime}, + body={"max_lifetime": one_hour_ms}, tok=self.token, ) - self._test_retention_event_purged(room_id, one_day_ms * 1.5) + # Check that the event is purged after waiting for the minimum allowed duration + # instead of the one specified in the room's policy. + self._test_retention_event_purged(room_id, one_day_ms * 0.5) def test_retention_event_purged_without_state_event(self): """Tests that expired events are correctly purged when the room's retention policy @@ -126,7 +139,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): events.append(self.get_success(store.get_event(valid_event_id))) - # Advance the time by anothe 2 days. After this, the first event should be + # Advance the time by another 2 days. After this, the first event should be # outdated but not the second one. self.reactor.advance(one_day_ms * 2 / 1000) @@ -140,11 +153,33 @@ class RetentionTestCase(unittest.HomeserverTestCase): # That event should be the second, not outdated event. self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) - def _test_retention_event_purged(self, room_id, increment): + def _test_retention_event_purged(self, room_id: str, increment: float): + """Run the following test scenario to test the message retention policy support: + + 1. Send event 1 + 2. Increment time by `increment` + 3. Send event 2 + 4. Increment time by `increment` + 5. Check that event 1 has been purged + 6. Check that event 2 has not been purged + 7. Check that state events that were sent before event 1 aren't purged. + The main reason for sending a second event is because currently Synapse won't + purge the latest message in a room because it would otherwise result in a lack of + forward extremities for this room. It's also a good thing to ensure the purge jobs + aren't too greedy and purge messages they shouldn't. + + Args: + room_id: The ID of the room to test retention in. + increment: The number of milliseconds to advance the clock each time. Must be + defined so that events in the room aren't purged if they are `increment` + old but are purged if they are `increment * 2` old. + """ # Get the create event to, later, check that we can still access it. message_handler = self.hs.get_message_handler() create_event = self.get_success( - message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + message_handler.get_room_data( + self.user_id, room_id, EventTypes.Create, state_key="" + ) ) # Send a first event to the room. This is the event we'll want to be purged at the @@ -154,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): expired_event_id = resp.get("event_id") # Check that we can retrieve the event. - expired_event = self.get_event(room_id, expired_event_id) + expired_event = self.get_event(expired_event_id) self.assertEqual( expired_event.get("content", {}).get("body"), "1", expired_event ) @@ -172,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase): # one should still be kept. self.reactor.advance(increment / 1000) - # Check that the event has been purged from the database. - self.get_event(room_id, expired_event_id, expected_code=404) + # Check that the first event has been purged from the database, i.e. that we + # can't retrieve it anymore, because it has expired. + self.get_event(expired_event_id, expect_none=True) - # Check that the event that hasn't been purged can still be retrieved. - valid_event = self.get_event(room_id, valid_event_id) + # Check that the event that hasn't expired can still be retrieved. + valid_event = self.get_event(valid_event_id) self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) # Check that we can still access state events that were sent before the event that # has been purged. self.get_event(room_id, create_event.event_id) - def get_event(self, room_id, event_id, expected_code=200): - url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + def get_event(self, event_id, expect_none=False): + event = self.get_success(self.store.get_event(event_id, allow_none=True)) - request, channel = self.make_request("GET", url, access_token=self.token) - self.render(request) + if expect_none: + self.assertIsNone(event) + return {} - self.assertEqual(channel.code, expected_code, channel.result) + self.assertIsNotNone(event) - return channel.json_body + time_now = self.clock.time_msec() + serialized = self.get_success(self.serializer.serialize_event(event, time_now)) + + return serialized class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py new file mode 100644 index 0000000000..dfe4bf7762 --- /dev/null +++ b/tests/rest/client/test_shadow_banned.py @@ -0,0 +1,312 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock, patch + +import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.rest.client.v1 import directory, login, profile, room +from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet + +from tests import unittest + + +class _ShadowBannedBase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + # Create two users, one of which is shadow-banned. + self.banned_user_id = self.register_user("banned", "test") + self.banned_access_token = self.login("banned", "test") + + self.store = self.hs.get_datastore() + + self.get_success( + self.store.db_pool.simple_update( + table="users", + keyvalues={"name": self.banned_user_id}, + updatevalues={"shadow_banned": True}, + desc="shadow_ban", + ) + ) + + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + +# To avoid the tests timing out don't add a delay to "annoy the requester". +@patch("random.randint", new=lambda a, b: 0) +class RoomTestCase(_ShadowBannedBase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def test_invite(self): + """Invites from shadow-banned users don't actually get sent.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + self.helper.invite( + room=room_id, + src=self.banned_user_id, + tok=self.banned_access_token, + targ=self.other_user_id, + ) + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + def test_invite_3pid(self): + """Ensure that a 3PID invite does not attempt to contact the identity server.""" + identity_handler = self.hs.get_handlers().identity_handler + identity_handler.lookup_3pid = Mock( + side_effect=AssertionError("This should not get called") + ) + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + request, channel = self.make_request( + "POST", + "/rooms/%s/invite" % (room_id,), + {"id_server": "test", "medium": "email", "address": "test@test.test"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # This should have raised an error earlier, but double check this wasn't called. + identity_handler.lookup_3pid.assert_not_called() + + def test_create_room(self): + """Invitations during a room creation should be discarded, but the room still gets created.""" + # The room creation is successful. + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/createRoom", + {"visibility": "public", "invite": [self.other_user_id]}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + room_id = channel.json_body["room_id"] + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + # Since a real room was created, the other user should be able to join it. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + # Both users should be in the room. + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) + + def test_message(self): + """Messages from shadow-banned users don't actually get sent.""" + + room_id = self.helper.create_room_as( + self.other_user_id, tok=self.other_access_token + ) + + # The user should be in the room. + self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token) + + # Sending a message should complete successfully. + result = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "with right label"}, + tok=self.banned_access_token, + ) + self.assertIn("event_id", result) + event_id = result["event_id"] + + latest_events = self.get_success( + self.store.get_latest_event_ids_in_room(room_id) + ) + self.assertNotIn(event_id, latest_events) + + def test_upgrade(self): + """A room upgrade should fail, but look like it succeeded.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), + {"new_version": "6"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + # A new room_id should be returned. + self.assertIn("replacement_room", channel.json_body) + + new_room_id = channel.json_body["replacement_room"] + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(new_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) + + def test_typing(self): + """Typing notifications should not be propagated into the room.""" + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (room_id, self.banned_user_id), + {"typing": True, "timeout": 30000}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # There should be no typing events. + event_source = self.hs.get_event_sources().sources["typing"] + self.assertEquals(event_source.get_current_key(), 0) + + # The other user can join and send typing events. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + request, channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (room_id, self.other_user_id), + {"typing": True, "timeout": 30000}, + access_token=self.other_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # These appear in the room. + self.assertEquals(event_source.get_current_key(), 1) + events = self.get_success( + event_source.get_new_events(from_key=0, room_ids=[room_id]) + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": room_id, + "content": {"user_ids": [self.other_user_id]}, + } + ], + ) + + +# To avoid the tests timing out don't add a delay to "annoy the requester". +@patch("random.randint", new=lambda a, b: 0) +class ProfileTestCase(_ShadowBannedBase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def test_displayname(self): + """Profile changes should succeed, but don't end up in a room.""" + original_display_name = "banned" + new_display_name = "new name" + + # Join a room. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # The update should succeed. + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,), + {"displayname": new_display_name}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertEqual(channel.json_body, {}) + + # The user's display name should be updated. + request, channel = self.make_request( + "GET", "/profile/%s/displayname" % (self.banned_user_id,) + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["displayname"], new_display_name) + + # But the display name in the room should not be. + message_handler = self.hs.get_message_handler() + event = self.get_success( + message_handler.get_room_data( + self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + ) + ) + self.assertEqual( + event.content, {"membership": "join", "displayname": original_display_name} + ) + + def test_room_displayname(self): + """Changes to state events for a room should be processed, but not end up in the room.""" + original_display_name = "banned" + new_display_name = "new name" + + # Join a room. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # The update should succeed. + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (room_id, self.banned_user_id), + {"membership": "join", "displayname": new_display_name}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertIn("event_id", channel.json_body) + + # The display name in the room should not be changed. + message_handler = self.hs.get_message_handler() + event = self.get_success( + message_handler.get_room_data( + self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + ) + ) + self.assertEqual( + event.content, {"membership": "join", "displayname": original_display_name} + ) diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py index 7167fc56b6..8c24add530 100644 --- a/tests/rest/client/third_party_rules.py +++ b/tests/rest/client/third_party_rules.py @@ -19,7 +19,7 @@ from synapse.rest.client.v1 import login, room from tests import unittest -class ThirdPartyRulesTestModule(object): +class ThirdPartyRulesTestModule: def __init__(self, config): pass diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 9033f09fd2..2668662c9e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) @@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) @@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) @@ -398,7 +392,7 @@ class CASTestCase(unittest.HomeserverTestCase): </cas:serviceResponse> """ % cas_user_id - ) + ).encode("utf-8") mocked_http_client = Mock(spec=["get_raw"]) mocked_http_client.get_raw.side_effect = get_raw @@ -514,19 +508,22 @@ class JWTTestCase(unittest.HomeserverTestCase): ] jwt_secret = "secret" + jwt_algorithm = "HS256" def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() self.hs.config.jwt_enabled = True self.hs.config.jwt_secret = self.jwt_secret - self.hs.config.jwt_algorithm = "HS256" + self.hs.config.jwt_algorithm = self.jwt_algorithm return self.hs def jwt_encode(self, token, secret=jwt_secret): - return jwt.encode(token, secret, "HS256").decode("ascii") + return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii") def jwt_login(self, *args): - params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + params = json.dumps( + {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + ) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) return channel @@ -544,35 +541,126 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_invalid_signature(self): channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") - self.assertEqual(channel.json_body["error"], "Invalid JWT") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) def test_login_jwt_expired(self): channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") - self.assertEqual(channel.json_body["error"], "JWT expired") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Signature has expired" + ) def test_login_jwt_not_before(self): now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") - self.assertEqual(channel.json_body["error"], "Invalid JWT") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: The token is not yet valid (nbf)", + ) def test_login_no_sub(self): channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "issuer": "test-issuer", + } + } + ) + def test_login_iss(self): + """Test validating the issuer claim.""" + # A valid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid issuer" + ) + + # Not providing an issuer. + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "iss" claim', + ) + + def test_login_iss_no_config(self): + """Test providing an issuer claim without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "audiences": ["test-audience"], + } + } + ) + def test_login_aud(self): + """Test validating the audience claim.""" + # A valid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + + # Not providing an audience. + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "aud" claim', + ) + + def test_login_aud_no_config(self): + """Test providing an audience without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + def test_login_no_token(self): - params = json.dumps({"type": "m.login.jwt"}) + params = json.dumps({"type": "org.matrix.login.jwt"}) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -640,7 +728,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return jwt.encode(token, secret, "RS256").decode("ascii") def jwt_login(self, *args): - params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + params = json.dumps( + {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + ) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) return channel @@ -652,6 +742,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def test_login_jwt_invalid_signature(self): channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"401", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") - self.assertEqual(channel.json_body["error"], "Invalid JWT") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 0fdff79aa7..3c66255dac 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def test_put_presence_disabled(self): """ - PUT to the status endpoint with use_presence disbled will NOT call + PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. """ self.hs.config.use_presence = False diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 8df58b4a63..ace0a3c08d 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase): profile_handler=self.mock_handler, ) - def _get_user_by_req(request=None, allow_guest=False): - return defer.succeed(synapse.types.create_requester(myid)) + async def _get_user_by_req(request=None, allow_guest=False): + return synapse.types.create_requester(myid) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 4886bbb401..0a567b032f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -19,18 +19,16 @@ """Tests REST events for /rooms paths.""" import json +from urllib import parse as urlparse from mock import Mock -from six.moves.urllib import parse as urlparse - -from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account -from synapse.types import JsonDict, RoomAlias +from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string from tests import unittest @@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase): self.hs.get_federation_handler = Mock(return_value=Mock()) - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None self.hs.get_datastore().insert_client_ip = _insert_client_ip @@ -677,6 +675,92 @@ class RoomMemberStateTestCase(RoomBase): self.assertEquals(json.loads(content), channel.json_body) +class RoomJoinRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit(self): + """Tests that local joins are actually rate-limited.""" + for i in range(3): + self.helper.create_room_as(self.user_id) + + self.helper.create_room_as(self.user_id, expect_code=429) + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_profile_change(self): + """Tests that sending a profile update into all of the user's joined rooms isn't + rate-limited by the rate-limiter on joins.""" + + # Create and join as many rooms as the rate-limiting config allows in a second. + room_ids = [ + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + ] + # Let some time for the rate-limiter to forget about our multi-join. + self.reactor.advance(2) + # Add one to make sure we're joined to more rooms than the config allows us to + # join in a second. + room_ids.append(self.helper.create_room_as(self.user_id)) + + # Create a profile for the user, since it hasn't been done on registration. + store = self.hs.get_datastore() + self.get_success( + store.create_profile(UserID.from_string(self.user_id).localpart) + ) + + # Update the display name for the user. + path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id + request, channel = self.make_request("PUT", path, {"displayname": "John Doe"}) + self.render(request) + self.assertEquals(channel.code, 200, channel.json_body) + + # Check that all the rooms have been sent a profile update into. + for room_id in room_ids: + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( + room_id, + self.user_id, + ) + + request, channel = self.make_request("GET", path) + self.render(request) + self.assertEquals(channel.code, 200) + + self.assertIn("displayname", channel.json_body) + self.assertEquals(channel.json_body["displayname"], "John Doe") + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_idempotent(self): + """Tests that the room join endpoints remain idempotent despite rate-limiting + on room joins.""" + room_id = self.helper.create_room_as(self.user_id) + + # Let's test both paths to be sure. + paths_to_test = [ + "/_matrix/client/r0/rooms/%s/join", + "/_matrix/client/r0/join/%s", + ] + + for path in paths_to_test: + # Make sure we send more requests than the rate-limiting config would allow + # if all of these requests ended up joining the user to a room. + for i in range(4): + request, channel = self.make_request("POST", path % room_id, {}) + self.render(request) + self.assertEquals(channel.code, 200) + + class RoomMessagesTestCase(RoomBase): """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """ diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 18260bb90e..94d2bf2eb1 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_handlers().federation_handler = Mock() - def get_user_by_access_token(token=None, allow_guest=False): + async def get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, @@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_auth().get_user_by_access_token = get_user_by_access_token - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 22d734e763..afaf9f7b85 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -30,7 +30,7 @@ from tests.server import make_request, render @attr.s -class RestHelper(object): +class RestHelper: """Contains extra helper functions to quickly and clearly perform a given REST action, which isn't the focus of the test. """ @@ -39,7 +39,9 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator=None, is_public=True, tok=None): + def create_room_as( + self, room_creator=None, is_public=True, tok=None, expect_code=200, + ): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" @@ -54,9 +56,11 @@ class RestHelper(object): ) render(request, self.resource, self.hs.get_reactor()) - assert channel.result["code"] == b"200", channel.result + assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id - return channel.json_body["room_id"] + + if expect_code == 200: + return channel.json_body["room_id"] def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): self.change_membership( @@ -88,7 +92,28 @@ class RestHelper(object): expect_code=expect_code, ) - def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): + def change_membership( + self, + room: str, + src: str, + targ: str, + membership: str, + extra_data: dict = {}, + tok: Optional[str] = None, + expect_code: int = 200, + ) -> None: + """ + Send a membership state event into a room. + + Args: + room: The ID of the room to send to + src: The mxid of the event sender + targ: The mxid of the event's target. The state key + membership: The type of membership event + extra_data: Extra information to include in the content of the event + tok: The user access token to use + expect_code: The expected HTTP response code + """ temp_id = self.auth_user_id self.auth_user_id = src @@ -97,6 +122,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok data = {"membership": membership} + data.update(extra_data) request, channel = make_request( self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 3ab611f618..152a5182fa 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -108,6 +108,46 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) + def test_basic_password_reset_canonicalise_email(self): + """Test basic password reset flow + Request password reset with different spelling + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email_profile = "test@example.com" + email_passwort_reset = "TEST@EXAMPLE.COM" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email_profile, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email_passwort_reset, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + def test_cant_reset_password_without_clicking_link(self): """Test that we do actually need to click the link in the email """ @@ -386,44 +426,67 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.email = "test@example.com" self.url_3pid = b"account/3pid" - def test_add_email(self): - """Test adding an email to profile - """ - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) + def test_add_valid_email(self): + self.get_success(self._add_email(self.email, self.email)) - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() + def test_add_valid_email_second_time(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) - self._validate_token(link) + def test_add_valid_email_second_time_canonicalise(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, + def test_add_email_no_at(self): + self.get_success( + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def test_add_email_two_at(self): + self.get_success( + self._request_token_invalid_email( + "foo@foo@test.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + def test_add_email_bad_format(self): + self.get_success( + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + def test_add_email_domain_to_lower(self): + self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + + def test_add_email_domain_with_umlaut(self): + self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) + + def test_add_email_address_casefold(self): + self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + + def test_address_trim(self): + self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed @@ -616,6 +679,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): return channel.json_body["sid"] + def _request_token_invalid_email( + self, email, expected_errcode, expected_error, client_secret="foobar", + ): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(expected_errcode, channel.json_body["errcode"]) + self.assertEqual(expected_error, channel.json_body["error"]) + def _validate_token(self, link): # Remove the host path = link.replace("https://example.com", "") @@ -643,3 +719,42 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): assert match, "Could not find link in email" return match.group(0) + + def _add_email(self, request_email, expected_email): + """Test adding an email to profile + """ + client_secret = "foobar" + session_id = self._request_token(request_email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"]) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index e0e9e94fbf..de00350580 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.api.errors import Codes from synapse.rest.client.v2_alpha import filter @@ -73,8 +75,10 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self): - filter_id = self.filtering.add_user_filter( - user_localpart="apple", user_filter=self.EXAMPLE_FILTER + filter_id = defer.ensureDeferred( + self.filtering.add_user_filter( + user_localpart="apple", user_filter=self.EXAMPLE_FILTER + ) ) self.reactor.advance(1) filter_id = filter_id.result diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 7deaf5b24a..2fc3a60fc5 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -116,8 +116,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) + @override_config({"enable_registration": False}) def test_POST_disabled_registration(self): - self.hs.config.enable_registration = False request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) @@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index c7e5859970..99c9f4e928 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -15,8 +15,7 @@ import itertools import json - -import six +import urllib from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -100,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) def test_basic_paginate_relations(self): - """Tests that calling pagination API corectly the latest relations. + """Tests that calling pagination API correctly the latest relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") self.assertEquals(200, channel.code, channel.json_body) @@ -134,7 +133,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # Make sure next_batch has something in it that looks like it could be a # valid token. self.assertIsInstance( - channel.json_body.get("next_batch"), six.string_types, channel.json_body + channel.json_body.get("next_batch"), str, channel.json_body ) def test_repeated_paginate_relations(self): @@ -278,7 +277,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): prev_token = None found_event_ids = [] - encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8")) + encoded_key = urllib.parse.quote_plus("👍".encode("utf-8")) for _ in range(20): from_token = "" if prev_token: @@ -670,7 +669,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): query = "" if key: - query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) + query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) original_id = parent_id if parent_id else self.parent_id diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py new file mode 100644 index 0000000000..5ae72fd008 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse.rest.admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import shared_rooms + +from tests import unittest + + +class UserSharedRoomsTest(unittest.HomeserverTestCase): + """ + Tests the UserSharedRoomsServlet. + """ + + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + shared_rooms.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["update_user_directory"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def _get_shared_rooms(self, token, other_user): + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" + % other_user, + access_token=token, + ) + self.render(request) + return request, channel + + def test_shared_room_list_public(self): + """ + A room should show up in the shared list of rooms between two users + if it is public. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + def test_shared_room_list_private(self): + """ + A room should show up in the shared list of rooms between two users + if it is private. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + def test_shared_room_list_mixed(self): + """ + The shared room list between two users should contain both public and private + rooms. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token) + self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token) + self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token) + self.helper.join(room_public, user=u2, tok=u2_token) + self.helper.join(room_private, user=u1, tok=u1_token) + + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 2) + self.assertTrue(room_public in channel.json_body["joined"]) + self.assertTrue(room_private in channel.json_body["joined"]) + + def test_shared_room_list_after_leave(self): + """ + A room should no longer be considered shared if the other + user has left it. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Assert user directory is not empty + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + self.helper.leave(room, user=u1, tok=u1_token) + + request, channel = self._get_shared_rooms(u2_token, u1) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index fa3a3ec1bd..a31e44c97e 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v2_alpha import read_marker, sync from tests import unittest from tests.server import TimedOutException @@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. + self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + request, channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + ) + self._check_unread_count(1) + + # Check that room topic changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + ) + self._check_unread_count(2) + + # Check that encrypted messages increase the unread counter. + self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) + self._check_unread_count(3) + + # Check that custom events with a body increase the unread counter. + self.helper.send_event( + self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that edits don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": "hello", + "msgtype": "m.text", + "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + }, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that notices don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"body": "hello", "msgtype": "m.notice"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that tombstone events changes increase the unread counter. + self.helper.send_state( + self.room_id, + EventTypes.Tombstone, + {"replacement_room": "!someroom:test"}, + tok=self.tok2, + ) + self._check_unread_count(5) + + def _check_unread_count(self, expected_count: True): + """Syncs and compares the unread count with the expected value.""" + + request, channel = self.make_request( + "GET", self.url % self.next_batch, access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.json_body) + + room_entry = channel.json_body["rooms"]["join"][self.room_id] + self.assertEqual( + room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, + ) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 99eb477149..6850c666be 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect an outgoing GET request for the given key """ - def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json(destination, path, ignore_backoff=False, **kwargs): self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - def post_json(destination, path, data): + async def post_json(destination, path, data): self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, "/_matrix/key/v2/query", diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 1ca648ef2b..f4f3e56777 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -12,22 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import os import shutil import tempfile from binascii import unhexlify from io import BytesIO from typing import Optional +from urllib import parse from mock import Mock -from six.moves.urllib import parse import attr -import PIL.Image as Image from parameterized import parameterized_class +from PIL import Image as Image +from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.logging.context import make_deferred_yieldable @@ -79,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase): # This uses a real blocking threadpool so we have to wait for it to be # actually done :/ - x = self.media_storage.ensure_media_is_in_local_cache(file_info) + x = defer.ensureDeferred( + self.media_storage.ensure_media_is_in_local_cache(file_info) + ) # Hotloop until the threadpool does its job... self.wait_on_thread(x) @@ -232,7 +233,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(len(self.fetches), 1) self.assertEqual(self.fetches[0][1], "example.com") self.assertEqual( - self.fetches[0][2], "/_matrix/media/v1/download/" + self.media_id + self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id ) self.assertEqual(self.fetches[0][3], {"allow_remote": "false"}) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 2826211f32..c00a7b9114 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -12,8 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import os +import re + +from mock import patch import attr @@ -29,7 +32,7 @@ from tests.server import FakeTransport @attr.s -class FakeResponse(object): +class FakeResponse: version = attr.ib() code = attr.ib() phrase = attr.ib() @@ -40,7 +43,7 @@ class FakeResponse(object): @property def request(self): @attr.s - class FakeTransport(object): + class FakeTransport: absoluteURI = self.absoluteURI return FakeTransport() @@ -108,7 +111,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups = {} - class Resolver(object): + class Resolver: def resolveHostName( _self, resolutionReceiver, @@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.reactor.nameResolver = Resolver() def test_cache_returns_correct_type(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( "GET", "url_preview?url=http://matrix.org", shorthand=False @@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def test_non_ascii_preview_httpequiv(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"<html><head>" @@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_non_ascii_preview_content_type(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"<html><head>" @@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_overlong_title(self): - self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] end_content = ( b"<html><head>" @@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ IP addresses can be previewed directly. """ - self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( "GET", "url_preview?url=http://example.com", shorthand=False @@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Hardcode the URL resolving to the IP we want. self.lookups["example.com"] = [ (IPv4Address, "1.1.1.2"), - (IPv4Address, "8.8.8.8"), + (IPv4Address, "10.1.2.3"), ] request, channel = self.make_request( @@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Accept-Language header is sent to the remote server """ - self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] # Build and make a request to the server request, channel = self.make_request( @@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase): ), server.data, ) + + def test_oembed_photo(self): + """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" + # Route the HTTP version to an HTTP endpoint so that the tests work. + with patch.dict( + "synapse.rest.media.v1.preview_url_resource._oembed_patterns", + { + re.compile( + r"http://twitter\.com/.+/status/.+" + ): "http://publish.twitter.com/oembed", + }, + clear=True, + ): + + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "photo", + "url": "http://cdn.twitter.com/matrixdotorg", + } + oembed_content = json.dumps(result).encode("utf-8") + + end_content = ( + b"<html><head>" + b"<title>Some Title</title>" + b'<meta property="og:description" content="hi" />' + b"</head></html>" + ) + + request, channel = self.make_request( + "GET", + "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' + ) + % (len(oembed_content),) + + oembed_content + ) + + self.pump() + + client = self.reactor.tcpClients[1][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "Some Title", "og:description": "hi"} + ) + + def test_oembed_rich(self): + """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" + # Route the HTTP version to an HTTP endpoint so that the tests work. + with patch.dict( + "synapse.rest.media.v1.preview_url_resource._oembed_patterns", + { + re.compile( + r"http://twitter\.com/.+/status/.+" + ): "http://publish.twitter.com/oembed", + }, + clear=True, + ): + + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "rich", + "html": "<div>Content Preview</div>", + } + end_content = json.dumps(result).encode("utf-8") + + request, channel = self.make_request( + "GET", + "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, + {"og:title": None, "og:description": "Content Preview"}, + ) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py new file mode 100644 index 0000000000..2d021f6565 --- /dev/null +++ b/tests/rest/test_health.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.rest.health import HealthResource + +from tests import unittest + + +class HealthCheckTests(unittest.HomeserverTestCase): + def setUp(self): + super().setUp() + + # replace the JsonResource with a HealthResource. + self.resource = HealthResource() + + def test_health(self): + request, channel = self.make_request("GET", "/health", shorthand=False) + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/server.py b/tests/server.py index 1644710aa0..48e45c6c8b 100644 --- a/tests/server.py +++ b/tests/server.py @@ -2,8 +2,6 @@ import json import logging from io import BytesIO -from six import text_type - import attr from zope.interface import implementer @@ -37,7 +35,7 @@ class TimedOutException(Exception): @attr.s -class FakeChannel(object): +class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). @@ -174,7 +172,7 @@ def make_request( if not path.startswith(b"/"): path = b"/" + path - if isinstance(content, text_type): + if isinstance(content, str): content = content.encode("utf8") site = FakeSite() @@ -239,11 +237,12 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def __init__(self): self.threadpool = ThreadPool(self) + self._tcp_callbacks = {} self._udp = [] lookups = self.lookups = {} @implementer(IResolverSimple) - class FakeResolver(object): + class FakeResolver: def getHostByName(self, name, timeout=None): if name not in lookups: return fail(DNSLookupError("OH NO: unknown %s" % (name,))) @@ -270,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def getThreadPool(self): return self.threadpool + def add_tcp_client_callback(self, host, port, callback): + """Add a callback that will be invoked when we receive a connection + attempt to the given IP/port using `connectTCP`. + + Note that the callback gets run before we return the connection to the + client, which means callbacks cannot block while waiting for writes. + """ + self._tcp_callbacks[(host, port)] = callback + + def connectTCP(self, host, port, factory, timeout=30, bindAddress=None): + """Fake L{IReactorTCP.connectTCP}. + """ + + conn = super().connectTCP( + host, port, factory, timeout=timeout, bindAddress=None + ) + + callback = self._tcp_callbacks.get((host, port)) + if callback: + callback() + + return conn + class ThreadPool: """ @@ -349,7 +371,7 @@ def get_clock(): @attr.s(cmp=False) -class FakeTransport(object): +class FakeTransport: """ A twisted.internet.interfaces.ITransport implementation which sends all its data straight into an IProtocol object: it exists to connect two IProtocols together. @@ -488,7 +510,7 @@ class FakeTransport(object): try: self.other.dataReceived(to_write) except Exception as e: - logger.warning("Exception writing to protocol: %s", e) + logger.exception("Exception writing to protocol: %s", e) return self.buffer = self.buffer[len(to_write) :] diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 99908edba3..973338ea71 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import default_config @@ -66,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): raise Exception("Failed to find reference to ResourceLimitsServerNotices") self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(1000) + side_effect=lambda user_id: make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( return_value=defer.succeed(Mock()) @@ -79,7 +80,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=defer.succeed("!something:localhost") ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) + self._rlsn._store.get_tags_for_room = Mock( + side_effect=lambda user_id, room_id: make_awaitable({}) + ) @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self): @@ -101,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -119,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -155,7 +158,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(None) + side_effect=lambda user_id: make_awaitable(None) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -214,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -258,10 +261,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock(return_value=1000) + self.store.get_monthly_active_count = Mock( + side_effect=lambda: make_awaitable(1000) + ) self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(1000) + side_effect=lambda user_id: make_awaitable(1000) ) # Call the function multiple times to ensure we only send the notice once @@ -275,7 +280,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id) ) - token = self.get_success(self.event_source.get_current_token()) + token = self.event_source.get_current_token() events, _ = self.get_success( self.store.get_recent_events_for_room( room_id, limit=100, end_token=token.room_key diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index a44960203e..ad9bbef9d2 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -14,11 +14,12 @@ # limitations under the License. import itertools - -from six.moves import zip +from typing import List import attr +from twisted.internet import defer + from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event @@ -43,7 +44,12 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN} ORIGIN_SERVER_TS = 0 -class FakeEvent(object): +class FakeClock: + def sleep(self, msec): + return defer.succeed(None) + + +class FakeEvent: """A fake event we use as a convenience. NOTE: Again as a convenience we use "node_ids" rather than event_ids to @@ -419,6 +425,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + FakeClock(), ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], @@ -426,7 +433,7 @@ class StateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(event_map), ) - state_before = self.successResultOf(state_d) + state_before = self.successResultOf(defer.ensureDeferred(state_d)) state_after = dict(state_before) if fake_event.state_key is not None: @@ -567,6 +574,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + FakeClock(), ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], @@ -574,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(self.event_map), ) - state = self.successResultOf(state_d) + state = self.successResultOf(defer.ensureDeferred(state_d)) self.assert_dict(self.expected_combined_state, state) @@ -587,7 +595,7 @@ def pairwise(iterable): @attr.s -class TestStateResolutionStore(object): +class TestStateResolutionStore: event_map = attr.ib() def get_events(self, event_ids, allow_rejected=False): @@ -601,9 +609,11 @@ class TestStateResolutionStore(object): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + return defer.succeed( + {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + ) - def _get_auth_chain(self, event_ids): + def _get_auth_chain(self, event_ids: List[str]) -> List[str]: """Gets the full auth chain for a set of events (including rejected events). @@ -615,10 +625,10 @@ class TestStateResolutionStore(object): presence of rejected events Args: - event_ids (list): The event IDs of the events to fetch the auth + event_ids: The event IDs of the events to fetch the auth chain for. Must be state events. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + List of event IDs of the auth chain. """ # Simple DFS for auth chain @@ -641,4 +651,4 @@ class TestStateResolutionStore(object): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) - return set(chains[0]).union(*chains[1:]) - common + return defer.succeed(set(chains[0]).union(*chains[1:]) - common) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 5a50e4fdd4..f5afed017c 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -99,7 +99,7 @@ class CacheTestCase(unittest.HomeserverTestCase): class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): - class A(object): + class A: @cached() def func(self, key): return key @@ -113,7 +113,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): def test_hit(self): callcount = [0] - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 @@ -131,7 +131,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): def test_invalidate(self): callcount = [0] - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 @@ -149,7 +149,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEquals(callcount[0], 2) def test_invalidate_missing(self): - class A(object): + class A: @cached() def func(self, key): return key @@ -160,7 +160,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): def test_max_entries(self): callcount = [0] - class A(object): + class A: @cached(max_entries=10) def func(self, key): callcount[0] += 1 @@ -187,7 +187,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): d = defer.succeed(123) - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 @@ -205,7 +205,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): callcount = [0] callcount2 = [0] - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 @@ -238,7 +238,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): callcount = [0] callcount2 = [0] - class A(object): + class A: @cached(max_entries=2) def func(self, key): callcount[0] += 1 @@ -275,7 +275,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): callcount = [0] callcount2 = [0] - class A(object): + class A: @cached() def func(self, key): callcount[0] += 1 @@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["hello"], ["there"]] self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "test", - self.storage.db.simple_upsert_many_txn, + self.storage.db_pool.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.db.simple_select_list( + self.storage.db_pool.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["bleb"]] self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "test", - self.storage.db.simple_upsert_many_txn, + self.storage.db_pool.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.db.simple_select_list( + self.storage.db_pool.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ef296e7dab..cb808d4de4 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -24,13 +24,14 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.data_stores.main.appservice import ( +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database, make_conn from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -178,14 +179,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_appservice_state_none(self): service = Mock(id="999") - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(None, state) @defer.inlineCallbacks def test_get_appservice_state_up(self): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.UP, state) @defer.inlineCallbacks @@ -194,20 +195,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) service = Mock(id=self.as_list[1]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) @defer.inlineCallbacks def test_get_appservices_by_state_none(self): - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) @defer.inlineCallbacks def test_set_appservices_state_down(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" @@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_multiple_up(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" @@ -234,7 +243,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def test_create_appservice_txn_first(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -246,7 +257,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_last_txn(service.id, 9643) # AS is falling behind yield self._insert_txn(service.id, 9644, events) yield self._insert_txn(service.id, 9645, events) - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -256,7 +269,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] yield self._set_last_txn(service.id, 9643) - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -277,7 +292,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._insert_txn(self.as_list[2]["id"], 10, events) yield self._insert_txn(self.as_list[3]["id"], 9643, events) - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -289,7 +306,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): txn_id = 1 yield self._insert_txn(service.id, txn_id, events) - yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) + yield defer.ensureDeferred( + self.store.complete_appservice_txn(txn_id=txn_id, service=service) + ) res = yield self.db_pool.runQuery( self.engine.convert_param_style( @@ -315,7 +334,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): txn_id = 5 yield self._set_last_txn(service.id, 4) yield self._insert_txn(service.id, txn_id, events) - yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) + yield defer.ensureDeferred( + self.store.complete_appservice_txn(txn_id=txn_id, service=service) + ) res = yield self.db_pool.runQuery( self.engine.convert_param_style( @@ -339,7 +360,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def test_get_oldest_unsent_txn_none(self): service = Mock(id=self.as_list[0]["id"]) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) @defer.inlineCallbacks @@ -349,14 +370,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=events) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) yield self._insert_txn(service.id, 11, other_events) yield self._insert_txn(service.id, 12, other_events) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) @@ -366,8 +387,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) @@ -379,8 +400,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) self.assertEquals( @@ -391,7 +412,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(TestTransactionStore, self).__init__(database, db_conn, hs) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 940b166129..02aae1c13d 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,7 +1,5 @@ from mock import Mock -from twisted.internet import defer - from synapse.storage.background_updates import BackgroundUpdater from tests import unittest @@ -9,7 +7,9 @@ from tests import unittest class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater + self.updates = ( + self.hs.get_datastore().db_pool.updates + ) # type: BackgroundUpdater # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -29,18 +29,17 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastore() self.get_success( - store.db.simple_insert( + store.db_pool.simple_insert( "background_updates", values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, ) ) # first step: make a bit of progress - @defer.inlineCallbacks - def update(progress, count): - yield self.clock.sleep((count * duration_ms) / 1000) + async def update(progress, count): + await self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield store.db.runInteraction( + await store.db_pool.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", @@ -65,13 +64,12 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # second step: complete the update # we should now get run with a much bigger number of items to update - @defer.inlineCallbacks - def update(progress, count): + async def update(progress, count): self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( count, target_background_update_duration_ms / duration_ms, places=0, ) - yield self.updates._end_background_update("test_update") + await self.updates._end_background_update("test_update") return count self.update_handler.side_effect = update diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 278961c331..40ba652248 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,11 +21,11 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import create_engine from tests import unittest -from tests.utils import TestHomeServer +from tests.utils import TestHomeServer, default_config class SQLBaseStoreTestCase(unittest.TestCase): @@ -49,10 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.db_pool.runWithConnection = runWithConnection - config = Mock() - config._disable_native_upserts = True - config.caches = Mock() - config.caches.event_cache_size = 1 + config = default_config(name="test", parse=True) hs = TestHomeServer("test", config=config) sqlite_config = {"name": "sqlite3"} @@ -60,7 +57,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) db._db_pool = self.db_pool self.datastore = SQLBaseStore(db, None, hs) @@ -69,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_insert( - table="tablename", values={"columname": "Value"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", values={"columname": "Value"} + ) ) self.mock_txn.execute.assert_called_with( @@ -81,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_insert( - table="tablename", - # Use OrderedDict() so we can assert on the SQL generated - values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", + # Use OrderedDict() so we can assert on the SQL generated + values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + ) ) self.mock_txn.execute.assert_called_with( @@ -96,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.db.simple_select_one_onecol( - table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + value = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one_onecol( + table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + ) ) self.assertEquals("Value", value) @@ -110,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.db.simple_select_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcols=["colA", "colB", "colC"], + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcols=["colA", "colB", "colC"], + ) ) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) @@ -126,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.db.simple_select_one( - table="tablename", - keyvalues={"keycol": "Not here"}, - retcols=["colA"], - allow_none=True, + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "Not here"}, + retcols=["colA"], + allow_none=True, + ) ) self.assertFalse(ret) @@ -141,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.db.simple_select_list( - table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_list( + table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ) ) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) @@ -154,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_update_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columnname": "New Value"}, + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + updatevalues={"columnname": "New Value"}, + ) ) self.mock_txn.execute.assert_called_with( @@ -169,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_update_one( - table="tablename", - keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), - updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), + updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + ) ) self.mock_txn.execute.assert_called_with( @@ -184,8 +197,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_delete_one( - table="tablename", keyvalues={"keycol": "Go away"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_one( + table="tablename", keyvalues={"keycol": "Go away"} + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 43425c969a..080761d1d2 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ # Make sure we don't clash with in progress updates. self.assertTrue( - self.store.db.updates._all_done, "Background updates are still ongoing" + self.store.db_pool.updates._all_done, "Background updates are still ongoing" ) schema_path = os.path.join( prepare_database.dir_path, - "data_stores", + "databases", "main", "schema", "delta", @@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "test_delete_forward_extremities", run_delta_file ) ) # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def test_soft_failed_extremities_handled_correctly(self): @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() @@ -271,7 +271,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Pump the reactor repeatedly so that the background updates have a # chance to run. - self.pump(10 * 60) + self.pump(20) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) @@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ "3" ] = 300000 + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() # All entries within time frame self.assertEqual( @@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): 3, ) # Oldest room to expire - self.pump(1) + self.pump(1.01) self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() self.assertEqual( len( diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 3b483bc7f0..370c247e16 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -16,13 +16,12 @@ from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -86,7 +85,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -117,7 +116,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -155,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): user_id = "@user:server" self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(lots_of_users) + side_effect=lambda: make_awaitable(lots_of_users) ) self.get_success( self.store.insert_client_ip( @@ -204,10 +203,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) user_id = "@user:id" @@ -225,7 +224,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But clear the associated entry in devices table self.get_success( - self.store.db.simple_update( + self.store.db_pool.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -252,7 +251,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -263,14 +262,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # We should now get the correct result again @@ -293,10 +292,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) user_id = "@user:id" @@ -315,7 +314,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -341,7 +340,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index c2539b353a..34ae8c9da7 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -34,9 +34,11 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_store_new_device(self): - yield self.store.store_device("user_id", "device_id", "display_name") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name") + ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertDictContainsSubset( { "user_id": "user_id", @@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_devices_by_user(self): - yield self.store.store_device("user_id", "device1", "display_name 1") - yield self.store.store_device("user_id", "device2", "display_name 2") - yield self.store.store_device("user_id2", "device3", "display_name 3") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device1", "display_name 1") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id", "device2", "display_name 2") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id2", "device3", "display_name 3") + ) - res = yield self.store.get_devices_by_user("user_id") + res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) self.assertDictContainsSubset( { @@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id - yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["somehost"] + yield defer.ensureDeferred( + self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "somehost", -1, limit=100 + now_stream_id, device_updates = yield defer.ensureDeferred( + self.store.get_device_updates_by_remote("somehost", -1, limit=100) ) # Check original device_ids are contained within these updates @@ -99,29 +107,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_update_device(self): - yield self.store.store_device("user_id", "device_id", "display_name 1") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name 1") + ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do a no-op first - yield self.store.update_device("user_id", "device_id") - res = yield self.store.get_device("user_id", "device_id") + yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do the update - yield self.store.update_device( - "user_id", "device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "device_id", new_display_name="display_name 2" + ) ) # check it worked - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 2", res["display_name"]) @defer.inlineCallbacks def test_update_unknown_device(self): with self.assertRaises(synapse.api.errors.StoreError) as cm: - yield self.store.update_device( - "user_id", "unknown_device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "unknown_device_id", new_display_name="display_name 2" + ) ) self.assertEqual(404, cm.exception.code) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 4e128e1047..da93ca3980 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -34,35 +34,53 @@ class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_room_to_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertEquals( ["#my-room:test"], - (yield self.store.get_aliases_for_room(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_aliases_for_room(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks def test_alias_to_room(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertObjectHasAttributes( {"room_id": self.room.to_string(), "servers": ["test"]}, - (yield self.store.get_association_from_room_alias(self.alias)), + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ), ) @defer.inlineCallbacks def test_delete_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) - room_id = yield self.store.delete_room_alias(self.alias) + room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias)) self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - (yield self.store.get_association_from_room_alias(self.alias)) + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ) ) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 398d546280..3fc4bb13b6 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -30,11 +30,15 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) - yield self.store.set_e2e_device_keys("user", "device", now, json) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -45,14 +49,18 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) - changed = yield self.store.set_e2e_device_keys("user", "device", now, json) + changed = yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) self.assertTrue(changed) # If we try to upload the same key then we should be told nothing # changed - changed = yield self.store.set_e2e_device_keys("user", "device", now, json) + changed = yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) self.assertFalse(changed) @defer.inlineCallbacks @@ -60,10 +68,16 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.set_e2e_device_keys("user", "device", now, json) - yield self.store.store_device("user", "device", "display_name") + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) + yield defer.ensureDeferred( + self.store.store_device("user", "device", "display_name") + ) - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -75,18 +89,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): def test_multiple_devices(self): now = 1470174257070 - yield self.store.store_device("user1", "device1", None) - yield self.store.store_device("user1", "device2", None) - yield self.store.store_device("user2", "device1", None) - yield self.store.store_device("user2", "device2", None) + yield defer.ensureDeferred(self.store.store_device("user1", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user1", "device2", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) - yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) - yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) - yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) - yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) + ) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) + ) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) + ) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) + ) - res = yield self.store.get_e2e_device_keys( - (("user1", "device1"), ("user2", "device2")) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys_for_cs_api( + (("user1", "device1"), ("user2", "device2")) + ) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 3aeec0dc0f..d4c3b867e3 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) for i in range(0, 20): - self.get_success(self.store.db.runInteraction("insert", insert_event, i)) + self.get_success( + self.store.db_pool.runInteraction("insert", insert_event, i) + ) # this should get the last ten r = self.get_success(self.store.get_prev_events_for_room(room_id)) @@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): for i in range(0, 20): self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room1) + self.store.db_pool.runInteraction("insert", insert_event, i, room1) ) self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room2) + self.store.db_pool.runInteraction("insert", insert_event, i, room2) ) self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room3) + self.store.db_pool.runInteraction("insert", insert_event, i, room3) ) # Test simple case @@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): depth = depth_map[event_id] - self.store.db.simple_insert_txn( + self.store.db_pool.simple_insert_txn( txn, table="events", values={ @@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): }, ) - self.store.db.simple_insert_many_txn( + self.store.db_pool.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): for event_id in auth_graph: next_stream_ordering += 1 self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "insert", insert_event, event_id, next_stream_ordering ) ) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b85004e5..949846fe33 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b45bc9c115..c0595963dd 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_http( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_http( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_email(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_email( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_email( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks @@ -56,12 +60,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.db.runInteraction( - "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 + counts = yield defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 + ) ) self.assertEquals( counts, - {"notify_count": noitf_count, "highlight_count": highlight_count}, + { + "notify_count": noitf_count, + "unread_count": 0, # Unread counts are tested in the sync tests. + "highlight_count": highlight_count, + }, ) @defer.inlineCallbacks @@ -72,28 +82,36 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - yield self.store.add_push_actions_to_staging( - event.event_id, {user_id: action} + yield defer.ensureDeferred( + self.store.add_push_actions_to_staging( + event.event_id, {user_id: action}, False, + ) ) - yield self.store.db.runInteraction( - "", - self.persist_events_store._set_push_actions_for_event_and_users_txn, - [(event, None)], - [(event, None)], + yield defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", + self.persist_events_store._set_push_actions_for_event_and_users_txn, + [(event, None)], + [(event, None)], + ) ) def _rotate(stream): - return self.store.db.runInteraction( - "", self.store._rotate_notifs_before_txn, stream + return defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", self.store._rotate_notifs_before_txn, stream + ) ) def _mark_read(stream, depth): - return self.store.db.runInteraction( - "", - self.store._remove_old_push_actions_before_txn, - room_id, - user_id, - stream, + return defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", + self.store._remove_old_push_actions_before_txn, + room_id, + user_id, + stream, + ) ) yield _assert_counts(0, 0) @@ -117,8 +135,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store.db.simple_delete( - table="event_push_actions", keyvalues={"1": 1}, desc="" + yield defer.ensureDeferred( + self.store.db_pool.simple_delete( + table="event_push_actions", keyvalues={"1": 1}, desc="" + ) ) yield _assert_counts(1, 0) @@ -136,33 +156,43 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.db.simple_insert( - "events", - { - "stream_ordering": so, - "received_ts": ts, - "event_id": "event%i" % so, - "type": "", - "room_id": "", - "content": "", - "processed": True, - "outlier": False, - "topological_ordering": 0, - "depth": 0, - }, + return defer.ensureDeferred( + self.store.db_pool.simple_insert( + "events", + { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }, + ) ) # start with the base case where there are no events in the table - r = yield self.store.find_first_stream_ordering_after_ts(11) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(11) + ) self.assertEqual(r, 0) # now with one event yield add_event(2, 10) - r = yield self.store.find_first_stream_ordering_after_ts(9) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(9) + ) self.assertEqual(r, 2) - r = yield self.store.find_first_stream_ordering_after_ts(10) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(10) + ) self.assertEqual(r, 2) - r = yield self.store.find_first_stream_ordering_after_ts(11) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(11) + ) self.assertEqual(r, 3) # add a bunch of dummy events to the events table @@ -175,25 +205,37 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ): yield add_event(stream_ordering, ts) - r = yield self.store.find_first_stream_ordering_after_ts(110) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(110) + ) self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r) # 4 and 5 are both after 120: we want 4 rather than 5 - r = yield self.store.find_first_stream_ordering_after_ts(120) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(120) + ) self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r) - r = yield self.store.find_first_stream_ordering_after_ts(129) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(129) + ) self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r) # check we can get the last event - r = yield self.store.find_first_stream_ordering_after_ts(140) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(140) + ) self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r) # off the end - r = yield self.store.find_first_stream_ordering_after_ts(160) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(160) + ) self.assertEqual(r, 21) # check we can find an event at ordering zero yield add_event(0, 5) - r = yield self.store.find_first_stream_ordering_after_ts(1) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(1) + ) self.assertEqual(r, 0) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 55e9ecf264..f0a8e32f1e 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -14,7 +14,7 @@ # limitations under the License. -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.unittest import HomeserverTestCase @@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db = self.store.db # type: Database + self.db_pool = self.store.db_pool # type: DatabasePool - self.get_success(self.db.runInteraction("_setup_db", self._setup_db)) + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) def _setup_db(self, txn): txn.execute("CREATE SEQUENCE foobar_seq") @@ -47,7 +47,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def _create(conn): return MultiWriterIdGenerator( conn, - self.db, + self.db_pool, instance_name=instance_name, table="foobar", instance_column="instance_name", @@ -55,9 +55,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): sequence_name="foobar_seq", ) - return self.get_success(self.db.runWithConnection(_create)) + return self.get_success(self.db_pool.runWithConnection(_create)) def _insert_rows(self, instance_name: str, number: int): + """Insert N rows as the given instance, inserting with stream IDs pulled + from the postgres sequence. + """ + def _insert(txn): for _ in range(number): txn.execute( @@ -65,7 +69,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): (instance_name,), ) - self.get_success(self.db.runInteraction("test_single_instance", _insert)) + self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) + + def _insert_row_with_id(self, instance_name: str, stream_id: int): + """Insert one row as the given instance with given stream_id, updating + the postgres sequence position to match. + """ + + def _insert(txn): + txn.execute( + "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), + ) + txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) + + self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) def test_empty(self): """Test an ID generator against an empty database gives sensible @@ -88,7 +105,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -98,12 +115,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) def test_multi_instance(self): """Test that reads and writes from multiple processes are handled @@ -116,8 +133,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen = self._create_id_generator("second") self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(first_id_gen.get_current_token("first"), 3) - self.assertEqual(first_id_gen.get_current_token("second"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -166,7 +183,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -176,9 +193,179 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - self.get_success(self.db.runInteraction("test", _get_next_txn)) + self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) + + def test_get_persisted_upto_position(self): + """Test that `get_persisted_upto_position` correctly tracks updates to + positions. + """ + + # The following tests are a bit cheeky in that we notify about new + # positions via `advance` without *actually* advancing the postgres + # sequence. + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + id_gen = self._create_id_generator("first") + + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) + + # Min is 3 and there is a gap between 5, so we expect it to be 3. + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + # We advance "first" straight to 6. Min is now 5 but there is no gap so + # we expect it to be 6 + id_gen.advance("first", 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 6) + + # No gap, so we expect 7. + id_gen.advance("second", 7) + self.assertEqual(id_gen.get_persisted_upto_position(), 7) + + # We haven't seen 8 yet, so we expect 7 still. + id_gen.advance("second", 9) + self.assertEqual(id_gen.get_persisted_upto_position(), 7) + + # Now that we've seen 7, 8 and 9 we can got straight to 9. + id_gen.advance("first", 8) + self.assertEqual(id_gen.get_persisted_upto_position(), 9) + + # Jump forward with gaps. The minimum is 11, even though we haven't seen + # 10 we know that everything before 11 must be persisted. + id_gen.advance("first", 11) + id_gen.advance("second", 15) + self.assertEqual(id_gen.get_persisted_upto_position(), 11) + + def test_get_persisted_upto_position_get_next(self): + """Test that `get_persisted_upto_position` correctly tracks updates to + positions when `get_next` is called. + """ + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + id_gen = self._create_id_generator("first") + + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) + + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + with self.get_success(id_gen.get_next()) as stream_id: + self.assertEqual(stream_id, 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + self.assertEqual(id_gen.get_persisted_upto_position(), 6) + + # We assume that so long as `get_next` does correctly advance the + # `persisted_upto_position` in this case, then it will be correct in the + # other cases that are tested above (since they'll hit the same code). + + +class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): + """Tests MultiWriterIdGenerator that produce *negative* stream IDs. + """ + + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db_pool = self.store.db_pool # type: DatabasePool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db_pool, + instance_name=instance_name, + table="foobar", + instance_column="instance_name", + id_column="stream_id", + sequence_name="foobar_seq", + positive=False, + ) + + return self.get_success(self.db_pool.runWithConnection(_create)) + + def _insert_row(self, instance_name: str, stream_id: int): + """Insert one row as the given instance with given stream_id. + """ + + def _insert(txn): + txn.execute( + "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), + ) + + self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) + + def test_single_instance(self): + """Test that reads and writes from a single process are handled + correctly. + """ + id_gen = self._create_id_generator() + + with self.get_success(id_gen.get_next()) as stream_id: + self._insert_row("master", stream_id) + + self.assertEqual(id_gen.get_positions(), {"master": -1}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) + self.assertEqual(id_gen.get_persisted_upto_position(), -1) + + with self.get_success(id_gen.get_next_mult(3)) as stream_ids: + for stream_id in stream_ids: + self._insert_row("master", stream_id) + + self.assertEqual(id_gen.get_positions(), {"master": -4}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), -4) + self.assertEqual(id_gen.get_persisted_upto_position(), -4) + + # Test loading from DB by creating a second ID gen + second_id_gen = self._create_id_generator() + + self.assertEqual(second_id_gen.get_positions(), {"master": -4}) + self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) + self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) + + def test_multiple_instance(self): + """Tests that having multiple instances that get advanced over + federation works corretly. + """ + id_gen_1 = self._create_id_generator("first") + id_gen_2 = self._create_id_generator("second") + + with self.get_success(id_gen_1.get_next()) as stream_id: + self._insert_row("first", stream_id) + id_gen_2.advance("first", stream_id) + + self.assertEqual(id_gen_1.get_positions(), {"first": -1}) + self.assertEqual(id_gen_2.get_positions(), {"first": -1}) + self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) + self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) + + with self.get_success(id_gen_2.get_next()) as stream_id: + self._insert_row("second", stream_id) + id_gen_1.advance("second", stream_id) + + self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) + self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index ab0df5ea93..7e7f1286d9 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -34,12 +34,16 @@ class DataStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_users_paginate(self): - yield self.store.register_user(self.user.to_string(), "pass") - yield self.store.create_profile(self.user.localpart) - yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + yield defer.ensureDeferred( + self.store.register_user(self.user.to_string(), "pass") + ) + yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.user.localpart, self.displayname) + ) - users, total = yield self.store.get_users_paginate( - 0, 10, name="bc", guests=False + users, total = yield defer.ensureDeferred( + self.store.get_users_paginate(0, 10, name="bc", guests=False) ) self.assertEquals(1, total) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 9c04e92577..9870c74883 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 @@ -78,7 +79,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # XXX why are we doing this here? this function is only run at startup # so it is odd to re-run it here. self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) ) @@ -204,7 +205,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user, "email", email, now, now) ) - d = self.store.db.runInteraction( + d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.get_success(d) @@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) @@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( @@ -280,7 +287,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ] self.hs.config.mau_limits_reserved_threepids = threepids - d = self.store.db.runInteraction( + d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.get_success(d) @@ -293,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.register_user(user_id=user2, password_hash=None)) now = int(self.hs.get_clock().time_msec()) - self.store.user_add_threepid(user1, "email", user1_email, now, now) - self.store.user_add_threepid(user2, "email", user2_email, now, now) + self.get_success( + self.store.user_add_threepid(user1, "email", user1_email, now, now) + ) + self.get_success( + self.store.user_add_threepid(user2, "email", user2_email, now, now) + ) users = self.get_success(self.store.get_registered_reserved_users()) self.assertEqual(len(users), len(threepids)) @@ -333,7 +344,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9b6f7211ae..3fd0a38cf5 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -33,23 +33,36 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + ) self.assertEquals( - "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ), ) @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.u_frank.localpart, "http://my.site/here" + ) ) self.assertEquals( "http://my.site/here", - (yield self.store.get_profile_avatar_url(self.u_frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ), ) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index b9fafaa1a6..918387733b 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + +from synapse.api.errors import NotFoundError from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -44,28 +47,19 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_storage() # Get the topological token - event = store.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) + event = self.get_success( + store.get_topological_token_for_event(last["event_id"]) + ) # Purge everything before this topological token - purge = storage.purge_events.purge_history(self.room_id, event, True) - self.pump() - self.assertEqual(self.successResultOf(purge), None) - - # Try and get the events - get_first = store.get_event(first["event_id"]) - get_second = store.get_event(second["event_id"]) - get_third = store.get_event(third["event_id"]) - get_last = store.get_event(last["event_id"]) - self.pump() + self.get_success(storage.purge_events.purge_history(self.room_id, event, True)) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted # and last is not. - self.failureResultOf(get_first) - self.failureResultOf(get_second) - self.failureResultOf(get_third) - self.successResultOf(get_last) + self.get_failure(store.get_event(first["event_id"]), NotFoundError) + self.get_failure(store.get_event(second["event_id"]), NotFoundError) + self.get_failure(store.get_event(third["event_id"]), NotFoundError) + self.get_success(store.get_event(last["event_id"])) def test_purge_wont_delete_extrems(self): """ @@ -80,28 +74,21 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_datastore() # Set the topological token higher than it should be - event = storage.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) + event = self.get_success( + storage.get_topological_token_for_event(last["event_id"]) + ) event = "t{}-{}".format( *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) ) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) self.pump() f = self.failureResultOf(purge) self.assertIn("greater than forward", f.value.args[0]) # Try and get the events - get_first = storage.get_event(first["event_id"]) - get_second = storage.get_event(second["event_id"]) - get_third = storage.get_event(third["event_id"]) - get_last = storage.get_event(last["event_id"]) - self.pump() - - # Nothing is deleted. - self.successResultOf(get_first) - self.successResultOf(get_second) - self.successResultOf(get_third) - self.successResultOf(get_last) + self.get_success(storage.get_event(first["event_id"])) + self.get_success(storage.get_event(second["event_id"])) + self.get_success(storage.get_event(third["event_id"])) + self.get_success(storage.get_event(last["event_id"])) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index db3667dc43..1ea35d60c1 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def build(self, prev_event_ids): - built_event = yield self._base_builder.build(prev_event_ids) + built_event = yield defer.ensureDeferred( + self._base_builder.build(prev_event_ids) + ) built_event._event_id = self._event_id built_event._dict["event_id"] = self._event_id @@ -249,6 +251,10 @@ class RedactionTestCase(unittest.HomeserverTestCase): def room_id(self): return self._base_builder.room_id + @property + def type(self): + return self._base_builder.type + event_1, context_1 = self.get_success( self.event_creation_handler.create_new_client_event( EventIdManglingBuilder( @@ -341,7 +347,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -359,7 +365,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 71a40a0a49..6b582771fe 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes +from synapse.api.errors import ThreepidValidationError from tests import unittest from tests.utils import setup_test_homeserver @@ -36,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_register(self): - yield self.store.register_user(self.user_id, self.pwhash) + yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) self.assertEquals( { @@ -52,17 +53,21 @@ class RegistrationStoreTestCase(unittest.TestCase): "user_type": None, "deactivated": 0, }, - (yield self.store.get_user_by_id(self.user_id)), + (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), ) @defer.inlineCallbacks def test_add_tokens(self): - yield self.store.register_user(self.user_id, self.pwhash) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + ) ) - result = yield self.store.get_user_by_access_token(self.tokens[1]) + result = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[1]) + ) self.assertDictContainsSubset( {"name": self.user_id, "device_id": self.device_id}, result @@ -73,31 +78,41 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_user_delete_access_tokens(self): # add some tokens - yield self.store.register_user(self.user_id, self.pwhash) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + ) ) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + ) ) # now delete some - yield self.store.user_delete_access_tokens( - self.user_id, device_id=self.device_id + yield defer.ensureDeferred( + self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id) ) # check they were deleted - user = yield self.store.get_user_by_access_token(self.tokens[1]) + user = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[1]) + ) self.assertIsNone(user, "access token was not deleted by device_id") # check the one not associated with the device was not deleted - user = yield self.store.get_user_by_access_token(self.tokens[0]) + user = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[0]) + ) self.assertEqual(self.user_id, user["name"]) # now delete the rest - yield self.store.user_delete_access_tokens(self.user_id) + yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) - user = yield self.store.get_user_by_access_token(self.tokens[0]) + user = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[0]) + ) self.assertIsNone(user, "access token was not deleted without device_id") @defer.inlineCallbacks @@ -105,14 +120,48 @@ class RegistrationStoreTestCase(unittest.TestCase): TEST_USER = "@test:test" SUPPORT_USER = "@support:test" - res = yield self.store.is_support_user(None) + res = yield defer.ensureDeferred(self.store.is_support_user(None)) self.assertFalse(res) - yield self.store.register_user(user_id=TEST_USER, password_hash=None) - res = yield self.store.is_support_user(TEST_USER) + yield defer.ensureDeferred( + self.store.register_user(user_id=TEST_USER, password_hash=None) + ) + res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER)) self.assertFalse(res) - yield self.store.register_user( - user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT + yield defer.ensureDeferred( + self.store.register_user( + user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT + ) ) - res = yield self.store.is_support_user(SUPPORT_USER) + res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER)) self.assertTrue(res) + + @defer.inlineCallbacks + def test_3pid_inhibit_invalid_validation_session_error(self): + """Tests that enabling the configuration option to inhibit 3PID errors on + /requestToken also inhibits validation errors caused by an unknown session ID. + """ + + # Check that, with the config setting set to false (the default value), a + # validation error is caused by the unknown session ID. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Unknown session_id", e) + + # Set the config setting to true. + self.store._ignore_unknown_session_error = True + + # Check that now the validation error is caused by the token not matching. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Validation token not found or has expired", e) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 3b78d48896..bc8400f240 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id=self.u_creator.to_string(), - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id=self.u_creator.to_string(), + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -52,7 +54,13 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "is_public": True, }, - (yield self.store.get_room(self.room.to_string())), + (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), + ) + + @defer.inlineCallbacks + def test_get_room_unknown_room(self): + self.assertIsNone( + (yield defer.ensureDeferred(self.store.get_room("!uknown:test"))) ) @defer.inlineCallbacks @@ -63,7 +71,21 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "public": True, }, - (yield self.store.get_room_with_stats(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats(self.room.to_string()) + ) + ), + ) + + @defer.inlineCallbacks + def test_get_room_with_stats_unknown_room(self): + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats("!uknown:test") + ) + ), ) @@ -80,17 +102,21 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.storage.persistence.persist_event( - self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + yield defer.ensureDeferred( + self.storage.persistence.persist_event( + self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + ) ) @defer.inlineCallbacks @@ -101,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Name, name=name, content={"name": name}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( @@ -117,7 +145,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5dd46005e6..12ccc1f53e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -87,7 +87,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.inject_room_member(self.room, self.u_bob, Membership.JOIN) self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) - self.pump(20) + self.pump() self.assertTrue("_known_servers_count" not in self.store.__dict__.keys()) @@ -101,7 +101,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Initialises to 1 -- itself self.assertEqual(self.store._known_servers_count, 1) - self.pump(20) + self.pump() # No rooms have been joined, so technically the SQL returns 0, but it # will still say it knows about itself. @@ -111,25 +111,29 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.inject_room_member(self.room, self.u_bob, Membership.JOIN) self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) - self.pump(20) + self.pump(1) # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) def test_get_joined_users_from_context(self): room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - bob_event = event_injection.inject_member_event( - self.hs, room, self.u_bob, Membership.JOIN + bob_event = self.get_success( + event_injection.inject_member_event( + self.hs, room, self.u_bob, Membership.JOIN + ) ) # first, create a regular event - event, context = event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[bob_event.event_id], - type="m.test.1", - content={}, + event, context = self.get_success( + event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[bob_event.event_id], + type="m.test.1", + content={}, + ) ) users = self.get_success( @@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Regression test for #7376: create a state event whose key matches bob's # user_id, but which is *not* a membership event, and persist that; then check # that `get_joined_users_from_context` returns the correct users for the next event. - non_member_event = event_injection.inject_event( - self.hs, - room_id=room, - sender=self.u_bob, - prev_event_ids=[bob_event.event_id], - type="m.test.2", - state_key=self.u_bob, - content={}, + non_member_event = self.get_success( + event_injection.inject_event( + self.hs, + room_id=room, + sender=self.u_bob, + prev_event_ids=[bob_event.event_id], + type="m.test.2", + state_key=self.u_bob, + content={}, + ) ) - event, context = event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[non_member_event.event_id], - type="m.test.3", - content={}, + event, context = self.get_success( + event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[non_member_event.event_id], + type="m.test.3", + content={}, + ) ) users = self.get_success( self.store.get_joined_users_from_context(event, context) @@ -171,20 +179,20 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", @@ -195,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 0b88308ff4..8bd12fa847 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -64,11 +66,13 @@ class StateStoreTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups_ids( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] @@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.storage.state.get_state_for_event(e5.event_id) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event(e5.event_id) + ) self.assertIsNotNone(e4) @@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) self.assertStateMapEqual( @@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: {self.u_alice.to_string()}}, - include_others=True, - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ), + ) ) self.assertStateMapEqual( @@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True + ), + ) ) self.assertStateMapEqual( @@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.storage.state.get_state_groups_ids( - room_id, [e5.event_id] + group_ids = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 6a545d2eb0..738e912468 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -31,16 +31,24 @@ class UserDirectoryStoreTestCase(unittest.TestCase): # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. - yield self.store.update_profile_in_user_dir(ALICE, "alice", None) - yield self.store.update_profile_in_user_dir(BOB, "bob", None) - yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None) - yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(ALICE, "alice", None) + ) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(BOB, "bob", None) + ) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(BOBBY, "bobby", None) + ) + yield defer.ensureDeferred( + self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) + ) @defer.inlineCallbacks def test_search_user_dir(self): # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. - r = yield self.store.search_user_dir(ALICE, "bob", 10) + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(1, len(r["results"])) self.assertDictEqual( @@ -51,7 +59,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): def test_search_user_dir_all_users(self): self.hs.config.user_directory_search_all_users = True try: - r = yield self.store.search_user_dir(ALICE, "bob", 10) + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(2, len(r["results"])) self.assertDictEqual( diff --git a/tests/test_federation.py b/tests/test_federation.py index c662195eec..27a7fc9ed7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,7 +1,23 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from mock import Mock -from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed +from twisted.internet.defer import succeed +from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID @@ -10,6 +26,7 @@ from synapse.util.retryutils import NotRetryingDestination from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver +from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): @@ -26,24 +43,19 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, None, None) + our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room_deferred = ensureDeferred( + self.room_id = self.get_success( room_creator.create_room( - our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) - ) - self.reactor.advance(0.1) - self.room_id = self.successResultOf(room_deferred)[0]["room_id"] + )[0]["room_id"] self.store = self.homeserver.get_datastore() # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + most_recent = self.get_success( + self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) )[0] join_event = make_event_from_dict( @@ -73,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Send the join, it should return None (which is not an error) - d = ensureDeferred( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) + self.assertEqual( + self.get_success( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + ), + None, ) - self.reactor.advance(1) - self.assertEqual(self.successResultOf(d), None) # Make sure we actually joined the room self.assertEqual( - self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - )[0], + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], "$join:test.serv", ) @@ -95,7 +106,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): prev_events that said event references. """ - def post_json(destination, path, data, headers=None, timeout=0): + async def post_json(destination, path, data, headers=None, timeout=0): # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} @@ -103,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.http_client.post_json = post_json # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) + most_recent = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) )[0] # Now lie about an event @@ -124,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) with LoggingContext(request="lying_event"): - d = ensureDeferred( + failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True - ) + ), + FederationError, ) - # Step the reactor, so the database fetches come back - self.reactor.advance(1) - # on_receive_pdu should throw an error - failure = self.failureResultOf(d) self.assertEqual( failure.value.args[0], ( @@ -144,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + self.assertEqual(extrem[0], "$join:test.serv") def test_retry_device_list_resync(self): """Tests that device lists are marked as stale if they couldn't be synced, and @@ -173,7 +181,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. store = self.homeserver.get_datastore() - store.get_rooms_for_user = Mock(return_value=["!someroom:test"]) + store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. @@ -218,23 +226,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register mock device list retrieval on the federation client. federation_client = self.homeserver.get_federation_client() federation_client.query_user_devices = Mock( - return_value={ - "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { + return_value=succeed( + { "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" + remote_self_signing_key: remote_self_signing_key + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, }, - }, - } + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) ) # Resync the device list. diff --git a/tests/test_mau.py b/tests/test_mau.py index 49667ed7f4..654a6fa42d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -166,7 +166,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.do_sync_for_user(token5) self.do_sync_for_user(token6) - # But old user cant + # But old user can't with self.assertRaises(SynapseError) as cm: self.do_sync_for_user(token1) diff --git a/tests/test_server.py b/tests/test_server.py index e9a43b1e45..655c918a15 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -12,31 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import re -from six import StringIO - from twisted.internet.defer import Deferred -from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.resource import Resource -from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, RedirectException, SynapseError -from synapse.http.server import ( - DirectServeResource, - JsonResource, - OptionsResource, - wrap_html_request_handler, -) -from synapse.http.site import SynapseSite, logger +from synapse.config.server import parse_listener_def +from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource +from synapse.http.site import SynapseSite from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock from tests import unittest from tests.server import ( - FakeTransport, ThreadedMemoryReactorClock, make_request, render, @@ -168,6 +157,28 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + def test_head_request(self): + """ + JsonResource.handler_for_request gives correctly decoded URL args to + the callback, while Twisted will give the raw bytes of URL query + arguments. + """ + + def _callback(request, **kwargs): + return 200, {"result": True} + + res = JsonResource(self.homeserver) + res.register_paths( + "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet", + ) + + # The path was registered as GET, but this is a HEAD request. + request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"200") + self.assertNotIn("body", channel.result) + class OptionsResourceTests(unittest.TestCase): def setUp(self): @@ -189,7 +200,13 @@ class OptionsResourceTests(unittest.TestCase): request.prepath = [] # This doesn't get set properly by make_request. # Create a site and query for the resource. - site = SynapseSite("test", "site_tag", {}, self.resource, "1.0") + site = SynapseSite( + "test", + "site_tag", + parse_listener_def({"type": "http", "port": 0}), + self.resource, + "1.0", + ) request.site = site resource = site.getResourceFor(request) @@ -198,10 +215,10 @@ class OptionsResourceTests(unittest.TestCase): return channel def test_unknown_options_request(self): - """An OPTIONS requests to an unknown URL still returns 200 OK.""" + """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") - self.assertEqual(channel.result["code"], b"200") - self.assertEqual(channel.result["body"], b"{}") + self.assertEqual(channel.result["code"], b"204") + self.assertNotIn("body", channel.result) # Ensure the correct CORS headers have been added self.assertTrue( @@ -218,10 +235,10 @@ class OptionsResourceTests(unittest.TestCase): ) def test_known_options_request(self): - """An OPTIONS requests to an known URL still returns 200 OK.""" + """An OPTIONS requests to an known URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/res/") - self.assertEqual(channel.result["code"], b"200") - self.assertEqual(channel.result["body"], b"{}") + self.assertEqual(channel.result["code"], b"204") + self.assertNotIn("body", channel.result) # Ensure the correct CORS headers have been added self.assertTrue( @@ -250,18 +267,17 @@ class OptionsResourceTests(unittest.TestCase): class WrapHtmlRequestHandlerTests(unittest.TestCase): - class TestResource(DirectServeResource): + class TestResource(DirectServeHtmlResource): callback = None - @wrap_html_request_handler async def _async_render_GET(self, request): - return await self.callback(request) + await self.callback(request) def setUp(self): self.reactor = ThreadedMemoryReactorClock() def test_good_response(self): - def callback(request): + async def callback(request): request.write(b"response") request.finish() @@ -281,7 +297,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): with the right location. """ - def callback(request, **kwargs): + async def callback(request, **kwargs): raise RedirectException(b"/look/an/eagle", 301) res = WrapHtmlRequestHandlerTests.TestResource() @@ -301,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): returned too """ - def callback(request, **kwargs): + async def callback(request, **kwargs): e = RedirectException(b"/no/over/there", 304) e.cookies.append(b"session=yespls") raise e @@ -319,51 +335,18 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] self.assertEqual(cookies_headers, [b"session=yespls"]) + def test_head_request(self): + """A head request should work by being turned into a GET request.""" -class SiteTestCase(unittest.HomeserverTestCase): - def test_lose_connection(self): - """ - We log the URI correctly redacted when we lose the connection. - """ + async def callback(request): + request.write(b"response") + request.finish() - class HangingResource(Resource): - """ - A Resource that strategically hangs, as if it were processing an - answer. - """ + res = WrapHtmlRequestHandlerTests.TestResource() + res.callback = callback - def render(self, request): - return NOT_DONE_YET - - # Set up a logging handler that we can inspect afterwards - output = StringIO() - handler = logging.StreamHandler(output) - logger.addHandler(handler) - old_level = logger.level - logger.setLevel(10) - self.addCleanup(logger.setLevel, old_level) - self.addCleanup(logger.removeHandler, handler) - - # Make a resource and a Site, the resource will hang and allow us to - # time out the request while it's 'processing' - base_resource = Resource() - base_resource.putChild(b"", HangingResource()) - site = SynapseSite("test", "site_tag", {}, base_resource, "1.0") - - server = site.buildProtocol(None) - client = AccumulatingProtocol() - client.makeConnection(FakeTransport(server, self.reactor)) - server.makeConnection(FakeTransport(client, self.reactor)) - - # Send a request with an access token that will get redacted - server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n") - self.pump() - - # Lose the connection - e = Failure(Exception("Failed123")) - server.connectionLost(e) - handler.flush() - - # Our access token is redacted and the failure reason is logged. - self.assertIn("/?access_token=<redacted>", output.getvalue()) - self.assertIn("Failed123", output.getvalue()) + request, channel = make_request(self.reactor, b"HEAD", b"/path") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"200") + self.assertNotIn("body", channel.result) diff --git a/tests/test_state.py b/tests/test_state.py index 66f22f6813..2d58467932 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -71,7 +71,7 @@ def create_event( return event -class StateGroupStore(object): +class StateGroupStore: def __init__(self): self._event_to_state_group = {} self._group_to_state = {} @@ -80,16 +80,16 @@ class StateGroupStore(object): self._next_group = 1 - def get_state_groups_ids(self, room_id, event_ids): + async def get_state_groups_ids(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) if group: groups[group] = self._group_to_state[group] - return defer.succeed(groups) + return groups - def store_state_group( + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids ): state_group = self._next_group @@ -99,15 +99,15 @@ class StateGroupStore(object): return state_group - def get_events(self, event_ids, **kwargs): + async def get_events(self, event_ids, **kwargs): return { e_id: self._event_id_to_event[e_id] for e_id in event_ids if e_id in self._event_id_to_event } - def get_state_group_delta(self, name): - return None, None + async def get_state_group_delta(self, name): + return (None, None) def register_events(self, events): for e in events: @@ -119,7 +119,7 @@ class StateGroupStore(object): def register_event_id_state_group(self, event_id, state_group): self._event_to_state_group[event_id] = state_group - def get_room_version_id(self, room_id): + async def get_room_version_id(self, room_id): return RoomVersions.V1.identifier @@ -129,7 +129,7 @@ class DictObj(dict): self.__dict__ = self -class Graph(object): +class Graph: def __init__(self, nodes, edges): events = {} clobbered = set(events.keys()) @@ -202,14 +202,16 @@ class StateTestCase(unittest.TestCase): context_store = {} # type: dict[str, EventContext] for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertEqual(2, len(prev_state_ids)) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -244,7 +246,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -253,7 +257,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -300,7 +304,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -310,7 +316,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_e = context_store["E"] - prev_state_ids = yield ctx_e.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids()) self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @@ -373,7 +379,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -383,7 +391,7 @@ class StateTestCase(unittest.TestCase): ctx_b = context_store["B"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) @@ -411,12 +419,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -434,12 +444,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -462,18 +474,20 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( - prev_event_id, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state}, + group_name = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual( {e.event_id for e in old_state}, set(current_state_ids.values()) @@ -494,18 +508,20 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( - prev_event_id, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state}, + group_name = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) @@ -544,7 +560,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -586,7 +602,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -641,7 +657,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -669,29 +685,35 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) + @defer.inlineCallbacks def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): - sg1 = self.store.store_state_group( - prev_event_id_1, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state_1}, + sg1 = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id_1, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state_1}, + ) ) self.store.register_event_id_state_group(prev_event_id_1, sg1) - sg2 = self.store.store_state_group( - prev_event_id_2, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state_2}, + sg2 = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id_2, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state_2}, + ) ) self.store.register_event_id_state_group(prev_event_id_2, sg2) - return self.state.compute_event_context(event) + result = yield defer.ensureDeferred(self.state.compute_event_context(event)) + return result diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 5c2817cf28..b89798336c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -14,7 +14,6 @@ import json -import six from mock import Mock from twisted.test.proto_helpers import MemoryReactorClock @@ -60,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"401", channel.result) self.assertTrue(channel.json_body is not None) - self.assertIsInstance(channel.json_body["session"], six.text_type) + self.assertIsInstance(channel.json_body["session"], str) self.assertIsInstance(channel.json_body["flows"], list) for flow in channel.json_body["flows"]: @@ -125,6 +124,6 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) self.assertTrue(channel.json_body is not None) - self.assertIsInstance(channel.json_body["user_id"], six.text_type) - self.assertIsInstance(channel.json_body["access_token"], six.text_type) - self.assertIsInstance(channel.json_body["device_id"], six.text_type) + self.assertIsInstance(channel.json_body["user_id"], str) + self.assertIsInstance(channel.json_body["access_token"], str) + self.assertIsInstance(channel.json_body["device_id"], str) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 7b345b03bb..508aeba078 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,7 +17,7 @@ """ Utilities for running the unit tests """ -from typing import Awaitable, TypeVar +from typing import Any, Awaitable, TypeVar TV = TypeVar("TV") @@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: # if next didn't raise, the awaitable hasn't completed. raise Exception("awaitable has not yet completed") + + +async def make_awaitable(result: Any): + """Create an awaitable that just returns a result.""" + return result diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 431e9f8e5e..fb1ca90336 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,25 +13,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Optional, Tuple +from typing import List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Collection - -from tests.test_utils import get_awaitable_result - """ Utility functions for poking events into the storage of the server under test. """ -def inject_member_event( +async def inject_member_event( hs: synapse.server.HomeServer, room_id: str, sender: str, @@ -48,7 +43,7 @@ def inject_member_event( if extra_content: content.update(extra_content) - return inject_event( + return await inject_event( hs, room_id=room_id, type=EventTypes.Member, @@ -59,10 +54,10 @@ def inject_member_event( ) -def inject_event( +async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> EventBase: """Inject a generic event into a room @@ -74,37 +69,27 @@ def inject_event( prev_event_ids: prev_events for the event. If not specified, will be looked up kwargs: fields for the event to be created """ - test_reactor = hs.get_reactor() + event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - event, context = create_event(hs, room_version, prev_event_ids, **kwargs) - - d = hs.get_storage().persistence.persist_event(event, context) - test_reactor.advance(0) - get_awaitable_result(d) + await hs.get_storage().persistence.persist_event(event, context) return event -def create_event( +async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: - test_reactor = hs.get_reactor() - if room_version is None: - d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) - test_reactor.advance(0) - room_version = get_awaitable_result(d) + room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs ) - d = hs.get_event_creation_handler().create_new_client_event( + event, context = await hs.get_event_creation_handler().create_new_client_event( builder, prev_event_ids=prev_event_ids ) - test_reactor.advance(0) - event, context = get_awaitable_result(d) return event, context diff --git a/tests/test_visibility.py b/tests/test_visibility.py index f7381b2885..510b630114 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -37,10 +37,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.hs = yield setup_test_homeserver(self.addCleanup) self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() - self.store = self.hs.get_datastore() self.storage = self.hs.get_storage() - yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") + yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @defer.inlineCallbacks def test_filtering(self): @@ -53,7 +52,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # # before we do that, we persist some other events to act as state. - self.inject_visibility("@admin:hs", "joined") + yield self.inject_visibility("@admin:hs", "joined") for i in range(0, 10): yield self.inject_room_member("@resident%i:hs" % i) @@ -64,8 +63,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): evt = yield self.inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -99,11 +98,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) # the erasey user gets erased - yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") + yield defer.ensureDeferred( + self.hs.get_datastore().mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) for i in range(0, len(events_to_filter)): @@ -137,10 +138,12 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) + ) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) ) - yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -158,11 +161,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -179,11 +184,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -265,8 +272,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): storage.main = test_store storage.state = test_store - filtered = yield filter_events_for_server( - test_store, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(test_store, "test_server", events_to_filter) ) logger.info("Filtering took %f seconds", time.time() - start) @@ -287,7 +294,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): test_large_room.skip = "Disabled by default because it's slow" -class _TestStore(object): +class _TestStore: """Implements a few methods of the DataStore, so that we can test filter_events_for_server diff --git a/tests/unittest.py b/tests/unittest.py index 6b6f224e9c..3cb55a7e96 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -229,7 +229,7 @@ class HomeserverTestCase(TestCase): self.site = SynapseSite( logger_name="synapse.access.http.fake", site_tag="test", - config={}, + config=self.hs.config.server.listeners[0], resource=self.resource, server_version_string="1", ) @@ -241,20 +241,20 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: - def get_user_by_access_token(token=None, allow_guest=False): - return succeed( - { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } - ) - - def get_user_by_req(request, allow_guest=False, rights="access"): - return succeed( - create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None - ) + async def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.helper.auth_user_id), + "token_id": 1, + "is_guest": False, + } + + async def get_user_by_req(request, allow_guest=False, rights="access"): + return create_requester( + UserID.from_string(self.helper.auth_user_id), + 1, + False, + False, + None, ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -422,8 +422,8 @@ class HomeserverTestCase(TestCase): async def run_bg_updates(): with LoggingContext("run_bg_updates", request="run_bg_updates-1"): - while not await stor.db.updates.has_completed_background_updates(): - await stor.db.updates.do_next_background_update(1) + while not await stor.db_pool.updates.has_completed_background_updates(): + await stor.db_pool.updates.do_next_background_update(1) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() @@ -544,7 +544,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) event, context = self.get_success( event_creator.create_event( @@ -571,7 +571,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().db.simple_insert( + self.hs.get_datastore().db_pool.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", @@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase): user: MXID of the user to inject the membership for. membership: The membership type. """ - event_injection.inject_member_event(self.hs, room, user, membership) + self.get_success( + event_injection.inject_member_event(self.hs, room, user, membership) + ) class FederatingHomeserverTestCase(HomeserverTestCase): @@ -612,7 +614,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): """ def prepare(self, reactor, clock, homeserver): - class Authenticator(object): + class Authenticator: def authenticate_request(self, request, content): return succeed("other.example.com") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 4d2b9e0d64..677e925477 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -88,7 +88,7 @@ class CacheTestCase(unittest.TestCase): class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -122,7 +122,7 @@ class DescriptorTestCase(unittest.TestCase): def test_cache_num_args(self): """Only the first num_args arguments should matter to the cache""" - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -156,7 +156,7 @@ class DescriptorTestCase(unittest.TestCase): """If the wrapped function throws synchronously, things should continue to work """ - class Cls(object): + class Cls: @cached() def fn(self, arg1): raise SynapseError(100, "mai spoon iz too big!!1") @@ -180,7 +180,7 @@ class DescriptorTestCase(unittest.TestCase): complete_lookup = defer.Deferred() - class Cls(object): + class Cls: @descriptors.cached() def fn(self, arg1): @defer.inlineCallbacks @@ -223,7 +223,7 @@ class DescriptorTestCase(unittest.TestCase): """Check that the cache sets and restores logcontexts correctly when the lookup function throws an exception""" - class Cls(object): + class Cls: @descriptors.cached() def fn(self, arg1): @defer.inlineCallbacks @@ -263,7 +263,7 @@ class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache_default_args(self): - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -300,7 +300,7 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.assert_not_called() def test_cache_iterable(self): - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -336,7 +336,7 @@ class DescriptorTestCase(unittest.TestCase): """If the wrapped function throws synchronously, things should continue to work """ - class Cls(object): + class Cls: @descriptors.cached(iterable=True) def fn(self, arg1): raise SynapseError(100, "mai spoon iz too big!!1") @@ -358,7 +358,7 @@ class DescriptorTestCase(unittest.TestCase): class CachedListDescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1", inlineCallbacks=True) - def list_fn(self, args1, arg2): + @descriptors.cachedList("fn", "args1") + async def list_fn(self, args1, arg2): assert current_context().request == "c1" # we want this to behave like an asynchronous function - yield run_on_reactor() + await run_on_reactor() assert current_context().request == "c1" return self.mock(args1, arg2) @@ -408,7 +408,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def test_invalidate(self): """Make sure that invalidation callbacks are called.""" - class Cls(object): + class Cls: def __init__(self): self.mock = mock.Mock() @@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1", inlineCallbacks=True) - def list_fn(self, args1, arg2): + @descriptors.cachedList("fn", "args1") + async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function - yield run_on_reactor() + await run_on_reactor() return self.mock(args1, arg2) obj = Cls() diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index e90e08d1c0..2012263184 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -15,9 +15,9 @@ import threading +from io import StringIO from mock import NonCallableMock -from six import StringIO from twisted.internet import defer, reactor @@ -112,7 +112,7 @@ class FileConsumerTests(unittest.TestCase): self.assertTrue(string_file.closed) -class DummyPullProducer(object): +class DummyPullProducer: def __init__(self): self.consumer = None self.deferred = defer.Deferred() @@ -134,7 +134,7 @@ class DummyPullProducer(object): return d -class BlockingStringWrite(object): +class BlockingStringWrite: def __init__(self): self.buffer = "" self.closed = False diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index ca3858b184..0e52811948 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six.moves import range - from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 95301c013c..58ee918f65 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable(self): - # a function which retuns an incomplete deferred, but doesn't follow + # a function which returns an incomplete deferred, but doesn't follow # the synapse rules. def blocking_function(): d = defer.Deferred() @@ -183,7 +183,7 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable_with_await(self): - # an async function which retuns an incomplete coroutine, but doesn't + # an async function which returns an incomplete coroutine, but doesn't # follow the synapse rules. async def blocking_function(): diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 9e348694ad..5f46ed0cef 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self): """A happy-path case with a new destination and a successful operation""" store = self.hs.get_datastore() - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) # advance the clock a bit before making the request self.pump(1) @@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase): with limiter: pass - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) def test_limiter(self): """General test case which walks through the process of a failing request""" store = self.hs.get_datastore() - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) try: @@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass - # wait for the update to land - self.pump() - - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["retry_last_ts"], failure_ts) self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL) # now if we try again we should get a failure - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - self.failureResultOf(d, NotRetryingDestination) + self.get_failure( + get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination + ) # # advance the clock and try again # self.pump(MIN_RETRY_INTERVAL) - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) try: @@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass - # wait for the update to land - self.pump() - - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["retry_last_ts"], retry_ts) self.assertGreaterEqual( @@ -109,10 +91,8 @@ class RetryLimiterTestCase(HomeserverTestCase): # # one more go, with success # - self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) with limiter: @@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase): # wait for the update to land self.pump() - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index bd32e2cee7..d3dea3b52a 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer from synapse.util.async_helpers import ReadWriteLock @@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase): rwlock.read(key), # 5 rwlock.write(key), # 6 ] + ds = [defer.ensureDeferred(d) for d in ds] self._assert_called_before_not_after(ds, 2) @@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase): with ds[6].result: pass - d = rwlock.write(key) + d = defer.ensureDeferred(rwlock.write(key)) self.assertTrue(d.called) with d.result: pass - d = rwlock.read(key) + d = defer.ensureDeferred(rwlock.read(key)) self.assertTrue(d.called) with d.result: pass diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py index 4f4da29a98..8491f7cc83 100644 --- a/tests/util/test_stringutils.py +++ b/tests/util/test_stringutils.py @@ -28,9 +28,6 @@ class StringUtilsTestCase(unittest.TestCase): "_--something==_", "...--==-18913", "8Dj2odd-e9asd.cd==_--ddas-secret-", - # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766 - # To be removed in a future release - "SECRET:1234567890", ] bad = [ diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py new file mode 100644 index 0000000000..5513724d87 --- /dev/null +++ b/tests/util/test_threepids.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.threepids import canonicalise_email + +from tests.unittest import HomeserverTestCase + + +class CanonicaliseEmailTests(HomeserverTestCase): + def test_no_at(self): + with self.assertRaises(ValueError): + canonicalise_email("address-without-at.bar") + + def test_two_at(self): + with self.assertRaises(ValueError): + canonicalise_email("foo@foo@test.bar") + + def test_bad_format(self): + with self.assertRaises(ValueError): + canonicalise_email("user@bad.example.net@good.example.com") + + def test_valid_format(self): + self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar") + + def test_domain_to_lower(self): + self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar") + + def test_domain_with_umlaut(self): + self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com") + + def test_address_casefold(self): + self.assertEqual( + canonicalise_email("Strauß@Example.com"), "strauss@example.com" + ) + + def test_address_trim(self): + self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar") diff --git a/tests/utils.py b/tests/utils.py index 59c020a051..4673872f88 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,9 +21,9 @@ import time import uuid import warnings from inspect import getcallargs +from urllib import parse as urlparse from mock import Mock, patch -from six.moves.urllib import parse as urlparse from twisted.internet import defer, reactor @@ -154,6 +154,10 @@ def default_config(name, parse=False): "account": {"per_second": 10000, "burst_count": 10000}, "failed_attempts": {"per_second": 10000, "burst_count": 10000}, }, + "rc_joins": { + "local": {"per_second": 10000, "burst_count": 10000}, + "remote": {"per_second": 10000, "burst_count": 10000}, + }, "saml2_enabled": False, "public_baseurl": None, "default_identity_server": None, @@ -168,6 +172,7 @@ def default_config(name, parse=False): # background, which upsets the test runner. "update_user_directory": False, "caches": {"global_factor": 1}, + "listeners": [{"port": 0, "type": "http"}], } if parse: @@ -467,7 +472,7 @@ class MockHttpResource(HttpServer): self.callbacks.append((method, path_pattern, callback)) -class MockKey(object): +class MockKey: alg = "mock_alg" version = "mock_version" signature = b"\x9a\x87$" @@ -486,7 +491,7 @@ class MockKey(object): return b"<fake_encoded_key>" -class MockClock(object): +class MockClock: now = 1000 def __init__(self): @@ -563,7 +568,7 @@ def _format_call(args, kwargs): ) -class DeferredMockCallable(object): +class DeferredMockCallable: """A callable instance that stores a set of pending call expectations and return values for them. It allows a unit test to assert that the given set of function calls are eventually made, by awaiting on them to be called. @@ -637,14 +642,8 @@ class DeferredMockCallable(object): ) -@defer.inlineCallbacks -def create_room(hs, room_id, creator_id): +async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room - - Args: - hs - room_id (str) - creator_id (str) """ persistence_store = hs.get_storage().persistence @@ -652,7 +651,7 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - yield store.store_room( + await store.store_room( room_id=room_id, room_creator_user_id=creator_id, is_public=False, @@ -670,6 +669,6 @@ def create_room(hs, room_id, creator_id): }, ) - event, context = yield event_creation_handler.create_new_client_event(builder) + event, context = await event_creation_handler.create_new_client_event(builder) - yield persistence_store.persist_event(event, context) + await persistence_store.persist_event(event, context) diff --git a/tox.ini b/tox.ini index 463a34d137..df473bd234 100644 --- a/tox.ini +++ b/tox.ini @@ -2,7 +2,6 @@ envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort [base] -basepython = python3.7 deps = mock python-subunit @@ -120,27 +119,26 @@ commands = [testenv:check_codestyle] skip_install = True -basepython = python3.6 deps = flake8 flake8-comprehensions - black==19.10b0 # We pin so that our tests don't start failing on new releases of black. + # We pin so that our tests don't start failing on new releases of black. + black==19.10b0 commands = python -m black --check --diff . - /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}" + /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}" {toxinidir}/scripts-dev/config-lint.sh [testenv:check_isort] skip_install = True -deps = isort -commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests scripts-dev scripts" +deps = isort==5.0.3 +commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts" [testenv:check-newsfragment] skip_install = True deps = towncrier>=18.6.0rc1 commands = python -m towncrier.check --compare-with=origin/develop -basepython = python3.6 [testenv:check-sampleconfig] commands = {toxinidir}/scripts-dev/generate_sample_config --check @@ -171,48 +169,10 @@ commands= skip_install = True deps = {[base]deps} - mypy==0.750 + mypy==0.782 mypy-zope -env = - MYPYPATH = stubs/ extras = all -commands = mypy \ - synapse/api \ - synapse/appservice \ - synapse/config \ - synapse/event_auth.py \ - synapse/events/spamcheck.py \ - synapse/federation \ - synapse/handlers/auth.py \ - synapse/handlers/cas_handler.py \ - synapse/handlers/directory.py \ - synapse/handlers/oidc_handler.py \ - synapse/handlers/presence.py \ - synapse/handlers/room_member.py \ - synapse/handlers/room_member_worker.py \ - synapse/handlers/saml_handler.py \ - synapse/handlers/sync.py \ - synapse/handlers/ui_auth \ - synapse/http/server.py \ - synapse/http/site.py \ - synapse/logging/ \ - synapse/metrics \ - synapse/module_api \ - synapse/push/pusherpool.py \ - synapse/push/push_rule_evaluator.py \ - synapse/replication \ - synapse/rest \ - synapse/spam_checker_api \ - synapse/storage/data_stores/main/ui_auth.py \ - synapse/storage/database.py \ - synapse/storage/engines \ - synapse/storage/util \ - synapse/streams \ - synapse/util/caches/stream_change_cache.py \ - tests/replication \ - tests/test_utils \ - tests/rest/client/v2_alpha/test_auth.py \ - tests/util/test_stream_change_cache.py +commands = mypy # To find all folders that pass mypy you run: # |