diff options
author | Jason Robinson <jasonr@matrix.org> | 2021-01-23 21:41:35 +0200 |
---|---|---|
committer | Jason Robinson <jasonr@matrix.org> | 2021-01-23 21:41:35 +0200 |
commit | 8965b6cfec8a1de847efe3d1be4b9babf4622e2e (patch) | |
tree | 4551f104ee2ce840689aa5ecffa939938482ffd5 | |
parent | Add depth and received_ts to forward_extremities admin API response (diff) | |
parent | Return a 404 if no valid thumbnail is found. (#9163) (diff) | |
download | synapse-8965b6cfec8a1de847efe3d1be4b9babf4622e2e.tar.xz |
Merge branch 'develop' into jaywink/admin-forward-extremities
209 files changed, 7802 insertions, 2405 deletions
diff --git a/.gitignore b/.gitignore index 2bccf19997..2cef1b0a5a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ _trial_temp/ _trial_temp*/ /out +.DS_Store # stuff that is likely to exist when you run a server locally /*.db diff --git a/CHANGES.md b/CHANGES.md index 2aebe92cac..1c64007e54 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,118 @@ -Synapse 1.25.0rc1 (2021-01-06) +Synapse 1.26.0rc1 (2021-01-20) ============================== +This release brings a new schema version for Synapse and rolling back to a previous +version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details +on these changes and for general upgrade guidance. + +Features +-------- + +- Add support for multiple SSO Identity Providers. ([\#9015](https://github.com/matrix-org/synapse/issues/9015), [\#9017](https://github.com/matrix-org/synapse/issues/9017), [\#9036](https://github.com/matrix-org/synapse/issues/9036), [\#9067](https://github.com/matrix-org/synapse/issues/9067), [\#9081](https://github.com/matrix-org/synapse/issues/9081), [\#9082](https://github.com/matrix-org/synapse/issues/9082), [\#9105](https://github.com/matrix-org/synapse/issues/9105), [\#9107](https://github.com/matrix-org/synapse/issues/9107), [\#9109](https://github.com/matrix-org/synapse/issues/9109), [\#9110](https://github.com/matrix-org/synapse/issues/9110), [\#9127](https://github.com/matrix-org/synapse/issues/9127), [\#9153](https://github.com/matrix-org/synapse/issues/9153), [\#9154](https://github.com/matrix-org/synapse/issues/9154), [\#9177](https://github.com/matrix-org/synapse/issues/9177)) +- During user-interactive authentication via single-sign-on, give a better error if the user uses the wrong account on the SSO IdP. ([\#9091](https://github.com/matrix-org/synapse/issues/9091)) +- Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. ([\#9159](https://github.com/matrix-org/synapse/issues/9159)) +- Improve performance when calculating ignored users in large rooms. ([\#9024](https://github.com/matrix-org/synapse/issues/9024)) +- Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version. ([\#8984](https://github.com/matrix-org/synapse/issues/8984)) +- Add an admin API for protecting local media from quarantine. ([\#9086](https://github.com/matrix-org/synapse/issues/9086)) +- Remove a user's avatar URL and display name when deactivated with the Admin API. ([\#8932](https://github.com/matrix-org/synapse/issues/8932)) +- Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users. ([\#8948](https://github.com/matrix-org/synapse/issues/8948)) +- Add experimental support for handling to-device messages on worker processes. ([\#9042](https://github.com/matrix-org/synapse/issues/9042), [\#9043](https://github.com/matrix-org/synapse/issues/9043), [\#9044](https://github.com/matrix-org/synapse/issues/9044), [\#9130](https://github.com/matrix-org/synapse/issues/9130)) +- Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes. ([\#9068](https://github.com/matrix-org/synapse/issues/9068)) +- Add experimental support for handling `/devices` API on worker processes. ([\#9092](https://github.com/matrix-org/synapse/issues/9092)) +- Add experimental support for moving off receipts and account data persistence off master. ([\#9104](https://github.com/matrix-org/synapse/issues/9104), [\#9166](https://github.com/matrix-org/synapse/issues/9166)) + + +Bugfixes +-------- + +- Fix a long-standing issue where an internal server error would occur when requesting a profile over federation that did not include a display name / avatar URL. ([\#9023](https://github.com/matrix-org/synapse/issues/9023)) +- Fix a long-standing bug where some caches could grow larger than configured. ([\#9028](https://github.com/matrix-org/synapse/issues/9028)) +- Fix error handling during insertion of client IPs into the database. ([\#9051](https://github.com/matrix-org/synapse/issues/9051)) +- Fix bug where we didn't correctly record CPU time spent in `on_new_event` block. ([\#9053](https://github.com/matrix-org/synapse/issues/9053)) +- Fix a minor bug which could cause confusing error messages from invalid configurations. ([\#9054](https://github.com/matrix-org/synapse/issues/9054)) +- Fix incorrect exit code when there is an error at startup. ([\#9059](https://github.com/matrix-org/synapse/issues/9059)) +- Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers. ([\#9070](https://github.com/matrix-org/synapse/issues/9070)) +- Fix "Failed to send request" errors when a client provides an invalid room alias. ([\#9071](https://github.com/matrix-org/synapse/issues/9071)) +- Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0. ([\#9114](https://github.com/matrix-org/synapse/issues/9114), [\#9116](https://github.com/matrix-org/synapse/issues/9116)) +- Fix corruption of `pushers` data when a postgres bouncer is used. ([\#9117](https://github.com/matrix-org/synapse/issues/9117)) +- Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. ([\#9128](https://github.com/matrix-org/synapse/issues/9128)) +- Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files that are too large. ([\#9108](https://github.com/matrix-org/synapse/issues/9108)) +- Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0. ([\#9145](https://github.com/matrix-org/synapse/issues/9145)) +- Fix a long-standing bug "ValueError: invalid literal for int() with base 10" when `/publicRooms` is requested with an invalid `server` parameter. ([\#9161](https://github.com/matrix-org/synapse/issues/9161)) + + +Improved Documentation +---------------------- + +- Add some extra docs for getting Synapse running on macOS. ([\#8997](https://github.com/matrix-org/synapse/issues/8997)) +- Correct a typo in the `systemd-with-workers` documentation. ([\#9035](https://github.com/matrix-org/synapse/issues/9035)) +- Correct a typo in `INSTALL.md`. ([\#9040](https://github.com/matrix-org/synapse/issues/9040)) +- Add missing `user_mapping_provider` configuration to the Keycloak OIDC example. Contributed by @chris-ruecker. ([\#9057](https://github.com/matrix-org/synapse/issues/9057)) +- Quote `pip install` packages when extras are used to avoid shells interpreting bracket characters. ([\#9151](https://github.com/matrix-org/synapse/issues/9151)) + + +Deprecations and Removals +------------------------- + +- Remove broken and unmaintained `demo/webserver.py` script. ([\#9039](https://github.com/matrix-org/synapse/issues/9039)) + + +Internal Changes +---------------- + +- Improve efficiency of large state resolutions. ([\#8868](https://github.com/matrix-org/synapse/issues/8868), [\#9029](https://github.com/matrix-org/synapse/issues/9029), [\#9115](https://github.com/matrix-org/synapse/issues/9115), [\#9118](https://github.com/matrix-org/synapse/issues/9118), [\#9124](https://github.com/matrix-org/synapse/issues/9124)) +- Various clean-ups to the structured logging and logging context code. ([\#8939](https://github.com/matrix-org/synapse/issues/8939)) +- Ensure rejected events get added to some metadata tables. ([\#9016](https://github.com/matrix-org/synapse/issues/9016)) +- Ignore date-rotated homeserver logs saved to disk. ([\#9018](https://github.com/matrix-org/synapse/issues/9018)) +- Remove an unused column from `access_tokens` table. ([\#9025](https://github.com/matrix-org/synapse/issues/9025)) +- Add a `-noextras` factor to `tox.ini`, to support running the tests with no optional dependencies. ([\#9030](https://github.com/matrix-org/synapse/issues/9030)) +- Fix running unit tests when optional dependencies are not installed. ([\#9031](https://github.com/matrix-org/synapse/issues/9031)) +- Allow bumping schema version when using split out state database. ([\#9033](https://github.com/matrix-org/synapse/issues/9033)) +- Configure the linters to run on a consistent set of files. ([\#9038](https://github.com/matrix-org/synapse/issues/9038)) +- Various cleanups to device inbox store. ([\#9041](https://github.com/matrix-org/synapse/issues/9041)) +- Drop unused database tables. ([\#9055](https://github.com/matrix-org/synapse/issues/9055)) +- Remove unused `SynapseService` class. ([\#9058](https://github.com/matrix-org/synapse/issues/9058)) +- Remove unnecessary declarations in the tests for the admin API. ([\#9063](https://github.com/matrix-org/synapse/issues/9063)) +- Remove `SynapseRequest.get_user_agent`. ([\#9069](https://github.com/matrix-org/synapse/issues/9069)) +- Remove redundant `Homeserver.get_ip_from_request` method. ([\#9080](https://github.com/matrix-org/synapse/issues/9080)) +- Add type hints to media repository. ([\#9093](https://github.com/matrix-org/synapse/issues/9093)) +- Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung. ([\#9098](https://github.com/matrix-org/synapse/issues/9098)) +- Reduce the scope of caught exceptions in `BlacklistingAgentWrapper`. ([\#9106](https://github.com/matrix-org/synapse/issues/9106)) +- Improve `UsernamePickerTestCase`. ([\#9112](https://github.com/matrix-org/synapse/issues/9112)) +- Remove dependency on `distutils`. ([\#9125](https://github.com/matrix-org/synapse/issues/9125)) +- Enforce that replication HTTP clients are called with keyword arguments only. ([\#9144](https://github.com/matrix-org/synapse/issues/9144)) +- Fix the Python 3.5 / old dependencies build in CI. ([\#9146](https://github.com/matrix-org/synapse/issues/9146)) +- Replace the old `perspectives` option in the Synapse docker config file template with `trusted_key_servers`. ([\#9157](https://github.com/matrix-org/synapse/issues/9157)) + + +Synapse 1.25.0 (2021-01-13) +=========================== + +Ending Support for Python 3.5 and Postgres 9.5 +---------------------------------------------- + +With this release, the Synapse team is announcing a formal deprecation policy for our platform dependencies, like Python and PostgreSQL: + +All future releases of Synapse will follow the upstream end-of-life schedules. + +Which means: + +* This is the last release which guarantees support for Python 3.5. +* We will end support for PostgreSQL 9.5 early next month. +* We will end support for Python 3.6 and PostgreSQL 9.6 near the end of the year. + +Crucially, this means __we will not produce .deb packages for Debian 9 (Stretch) or Ubuntu 16.04 (Xenial)__ beyond the transition period described below. + +The website https://endoflife.date/ has convenient summaries of the support schedules for projects like [Python](https://endoflife.date/python) and [PostgreSQL](https://endoflife.date/postgresql). + +If you are unable to upgrade your environment to a supported version of Python or Postgres, we encourage you to consider using the [Synapse Docker images](./INSTALL.md#docker-images-and-ansible-playbooks) instead. + +### Transition Period + +We will make a good faith attempt to avoid breaking compatibility in all releases through the end of March 2021. However, critical security vulnerabilities in dependencies or other unanticipated circumstances may arise which necessitate breaking compatibility earlier. + +We intend to continue producing .deb packages for Debian 9 (Stretch) and Ubuntu 16.04 (Xenial) through the transition period. + Removal warning --------------- @@ -12,6 +124,15 @@ are deprecated and will be removed in a future release. They will be replaced by `POST /_synapse/admin/v1/rooms/<room_id>/delete` replaces `POST /_synapse/admin/v1/purge_room` and `POST /_synapse/admin/v1/shutdown_room/<room_id>`. +Bugfixes +-------- + +- Fix HTTP proxy support when using a proxy that is on a blacklisted IP. Introduced in v1.25.0rc1. Contributed by @Bubu. ([\#9084](https://github.com/matrix-org/synapse/issues/9084)) + + +Synapse 1.25.0rc1 (2021-01-06) +============================== + Features -------- @@ -61,7 +182,7 @@ Improved Documentation - Combine related media admin API docs. ([\#8839](https://github.com/matrix-org/synapse/issues/8839)) - Fix an error in the documentation for the SAML username mapping provider. ([\#8873](https://github.com/matrix-org/synapse/issues/8873)) - Clarify comments around template directories in `sample_config.yaml`. ([\#8891](https://github.com/matrix-org/synapse/issues/8891)) -- Moved instructions for database setup, adjusted heading levels and improved syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by fossterer. ([\#8987](https://github.com/matrix-org/synapse/issues/8987)) +- Move instructions for database setup, adjusted heading levels and improved syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by @fossterer. ([\#8987](https://github.com/matrix-org/synapse/issues/8987)) - Update the example value of `group_creation_prefix` in the sample configuration. ([\#8992](https://github.com/matrix-org/synapse/issues/8992)) - Link the Synapse developer room to the development section in the docs. ([\#9002](https://github.com/matrix-org/synapse/issues/9002)) diff --git a/INSTALL.md b/INSTALL.md index 598ddceb8c..d405d9fe55 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -190,7 +190,8 @@ via brew and inform `pip` about it so that `psycopg2` builds: ```sh brew install openssl@1.1 -export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/ +export LDFLAGS="-L/usr/local/opt/openssl/lib" +export CPPFLAGS="-I/usr/local/opt/openssl/include" ``` ##### OpenSUSE @@ -257,7 +258,7 @@ for a number of platforms. #### Docker images and Ansible playbooks -There is an offical synapse image available at +There is an official synapse image available at <https://hub.docker.com/r/matrixdotorg/synapse> which can be used with the docker-compose file available at [contrib/docker](contrib/docker). Further information on this including configuration options is available in the README diff --git a/README.rst b/README.rst index 31ae5cc578..d872b11f57 100644 --- a/README.rst +++ b/README.rst @@ -243,7 +243,7 @@ Then update the ``users`` table in the database:: Synapse Development =================== -Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org) +Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_ Before setting up a development environment for synapse, make sure you have the system dependencies (such as the python header files) installed - see @@ -280,6 +280,27 @@ differ):: PASSED (skips=15, successes=1322) +We recommend using the demo which starts 3 federated instances running on ports `8080` - `8082` + + ./demo/start.sh + +(to stop, you can use `./demo/stop.sh`) + +If you just want to start a single instance of the app and run it directly:: + + # Create the homeserver.yaml config once + python -m synapse.app.homeserver \ + --server-name my.domain.name \ + --config-path homeserver.yaml \ + --generate-config \ + --report-stats=[yes|no] + + # Start the app + python -m synapse.app.homeserver --config-path homeserver.yaml + + + + Running the Integration Tests ============================= diff --git a/UPGRADE.rst b/UPGRADE.rst index 54a40bd42f..d09dbd4e21 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -5,6 +5,16 @@ Before upgrading check if any special steps are required to upgrade from the version you currently have installed to the current version of Synapse. The extra instructions that may be required are listed later in this document. +* Check that your versions of Python and PostgreSQL are still supported. + + Synapse follows upstream lifecycles for `Python`_ and `PostgreSQL`_, and + removes support for versions which are no longer maintained. + + The website https://endoflife.date also offers convenient summaries. + + .. _Python: https://devguide.python.org/devcycle/#end-of-life-branches + .. _PostgreSQL: https://www.postgresql.org/support/versioning/ + * If Synapse was installed using `prebuilt packages <INSTALL.md#prebuilt-packages>`_, you will need to follow the normal process for upgrading those packages. @@ -75,9 +85,71 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.26.0 +==================== + +Rolling back to v1.25.0 after a failed upgrade +---------------------------------------------- + +v1.26.0 includes a lot of large changes. If something problematic occurs, you +may want to roll-back to a previous version of Synapse. Because v1.26.0 also +includes a new database schema version, reverting that version is also required +alongside the generic rollback instructions mentioned above. In short, to roll +back to v1.25.0 you need to: + +1. Stop the server +2. Decrease the schema version in the database: + + .. code:: sql + + UPDATE schema_version SET version = 58; + +3. Delete the ignored users & chain cover data: + + .. code:: sql + + DROP TABLE IF EXISTS ignored_users; + UPDATE rooms SET has_auth_chain_index = false; + + For PostgreSQL run: + + .. code:: sql + + TRUNCATE event_auth_chain_links; + TRUNCATE event_auth_chains; + + For SQLite run: + + .. code:: sql + + DELETE FROM event_auth_chain_links; + DELETE FROM event_auth_chains; + +4. Mark the deltas as not run (so they will re-run on upgrade). + + .. code:: sql + + DELETE FROM applied_schema_deltas WHERE version = 59 AND file = "59/01ignored_user.py"; + DELETE FROM applied_schema_deltas WHERE version = 59 AND file = "59/06chain_cover_index.sql"; + +5. Downgrade Synapse by following the instructions for your installation method + in the "Rolling back to older versions" section above. + Upgrading to v1.25.0 ==================== +Last release supporting Python 3.5 +---------------------------------- + +This is the last release of Synapse which guarantees support with Python 3.5, +which passed its upstream End of Life date several months ago. + +We will attempt to maintain support through March 2021, but without guarantees. + +In the future, Synapse will follow upstream schedules for ending support of +older versions of Python and PostgreSQL. Please upgrade to at least Python 3.6 +and PostgreSQL 9.6 as soon as possible. + Blacklisting IP ranges ---------------------- diff --git a/changelog.d/8939.misc b/changelog.d/8939.misc deleted file mode 100644 index bf94135fd5..0000000000 --- a/changelog.d/8939.misc +++ /dev/null @@ -1 +0,0 @@ -Various clean-ups to the structured logging and logging context code. diff --git a/changelog.d/8984.feature b/changelog.d/8984.feature deleted file mode 100644 index 4db629746e..0000000000 --- a/changelog.d/8984.feature +++ /dev/null @@ -1 +0,0 @@ -Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version. diff --git a/changelog.d/9015.feature b/changelog.d/9015.feature deleted file mode 100644 index 01a24dcf49..0000000000 --- a/changelog.d/9015.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9017.feature b/changelog.d/9017.feature deleted file mode 100644 index 01a24dcf49..0000000000 --- a/changelog.d/9017.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9018.misc b/changelog.d/9018.misc deleted file mode 100644 index bb31eb4a46..0000000000 --- a/changelog.d/9018.misc +++ /dev/null @@ -1 +0,0 @@ -Ignore date-rotated homeserver logs saved to disk. diff --git a/changelog.d/9023.bugfix b/changelog.d/9023.bugfix deleted file mode 100644 index deae64d933..0000000000 --- a/changelog.d/9023.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a longstanding issue where an internal server error would occur when requesting a profile over federation that did not include a display name / avatar URL. diff --git a/changelog.d/9024.feature b/changelog.d/9024.feature deleted file mode 100644 index 073dafbf83..0000000000 --- a/changelog.d/9024.feature +++ /dev/null @@ -1 +0,0 @@ -Improved performance when calculating ignored users in large rooms. diff --git a/changelog.d/9028.bugfix b/changelog.d/9028.bugfix deleted file mode 100644 index 66666886a4..0000000000 --- a/changelog.d/9028.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where some caches could grow larger than configured. diff --git a/changelog.d/9030.misc b/changelog.d/9030.misc deleted file mode 100644 index 267cfbf9f9..0000000000 --- a/changelog.d/9030.misc +++ /dev/null @@ -1 +0,0 @@ -Add a `-noextras` factor to `tox.ini`, to support running the tests with no optional dependencies. diff --git a/changelog.d/9031.misc b/changelog.d/9031.misc deleted file mode 100644 index f43611c385..0000000000 --- a/changelog.d/9031.misc +++ /dev/null @@ -1 +0,0 @@ -Fix running unit tests when optional dependencies are not installed. diff --git a/changelog.d/9033.misc b/changelog.d/9033.misc deleted file mode 100644 index e9a305c0e8..0000000000 --- a/changelog.d/9033.misc +++ /dev/null @@ -1 +0,0 @@ -Allow bumping schema version when using split out state database. diff --git a/changelog.d/9035.doc b/changelog.d/9035.doc deleted file mode 100644 index 2a7f0db518..0000000000 --- a/changelog.d/9035.doc +++ /dev/null @@ -1 +0,0 @@ -Corrected a typo in the `systemd-with-workers` documentation. diff --git a/changelog.d/9036.feature b/changelog.d/9036.feature deleted file mode 100644 index 01a24dcf49..0000000000 --- a/changelog.d/9036.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9038.misc b/changelog.d/9038.misc deleted file mode 100644 index 5b9e21a1db..0000000000 --- a/changelog.d/9038.misc +++ /dev/null @@ -1 +0,0 @@ -Configure the linters to run on a consistent set of files. diff --git a/changelog.d/9039.removal b/changelog.d/9039.removal deleted file mode 100644 index fb99283ed8..0000000000 --- a/changelog.d/9039.removal +++ /dev/null @@ -1 +0,0 @@ -Remove broken and unmaintained `demo/webserver.py` script. diff --git a/changelog.d/9041.misc b/changelog.d/9041.misc deleted file mode 100644 index 4952fbe8a2..0000000000 --- a/changelog.d/9041.misc +++ /dev/null @@ -1 +0,0 @@ -Various cleanups to device inbox store. diff --git a/changelog.d/9042.feature b/changelog.d/9042.feature deleted file mode 100644 index 4ec319f1f2..0000000000 --- a/changelog.d/9042.feature +++ /dev/null @@ -1 +0,0 @@ -Add experimental support for handling and persistence of to-device messages to happen on worker processes. diff --git a/changelog.d/9043.feature b/changelog.d/9043.feature deleted file mode 100644 index 4ec319f1f2..0000000000 --- a/changelog.d/9043.feature +++ /dev/null @@ -1 +0,0 @@ -Add experimental support for handling and persistence of to-device messages to happen on worker processes. diff --git a/changelog.d/9044.feature b/changelog.d/9044.feature deleted file mode 100644 index 4ec319f1f2..0000000000 --- a/changelog.d/9044.feature +++ /dev/null @@ -1 +0,0 @@ -Add experimental support for handling and persistence of to-device messages to happen on worker processes. diff --git a/changelog.d/9045.misc b/changelog.d/9045.misc new file mode 100644 index 0000000000..7f1886a0de --- /dev/null +++ b/changelog.d/9045.misc @@ -0,0 +1 @@ +Add tests to `test_user.UsersListTestCase` for List Users Admin API. \ No newline at end of file diff --git a/changelog.d/9051.bugfix b/changelog.d/9051.bugfix deleted file mode 100644 index 272be9d7a3..0000000000 --- a/changelog.d/9051.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix error handling during insertion of client IPs into the database. diff --git a/changelog.d/9053.bugfix b/changelog.d/9053.bugfix deleted file mode 100644 index 3d8bbf11a1..0000000000 --- a/changelog.d/9053.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix bug where we didn't correctly record CPU time spent in 'on_new_event' block. diff --git a/changelog.d/9054.bugfix b/changelog.d/9054.bugfix deleted file mode 100644 index 0bfe951f17..0000000000 --- a/changelog.d/9054.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a minor bug which could cause confusing error messages from invalid configurations. diff --git a/changelog.d/9057.doc b/changelog.d/9057.doc deleted file mode 100644 index d16686e7dc..0000000000 --- a/changelog.d/9057.doc +++ /dev/null @@ -1 +0,0 @@ -Add missing user_mapping_provider configuration to the Keycloak OIDC example. Contributed by @chris-ruecker. diff --git a/changelog.d/9129.misc b/changelog.d/9129.misc new file mode 100644 index 0000000000..7800be3e7e --- /dev/null +++ b/changelog.d/9129.misc @@ -0,0 +1 @@ +Various improvements to the federation client. diff --git a/changelog.d/9135.doc b/changelog.d/9135.doc new file mode 100644 index 0000000000..d11ba70de4 --- /dev/null +++ b/changelog.d/9135.doc @@ -0,0 +1 @@ +Add link to Matrix VoIP tester for turn-howto. diff --git a/changelog.d/9163.bugfix b/changelog.d/9163.bugfix new file mode 100644 index 0000000000..c51cf6ca80 --- /dev/null +++ b/changelog.d/9163.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled). diff --git a/changelog.d/9176.misc b/changelog.d/9176.misc new file mode 100644 index 0000000000..9c41d7b0f9 --- /dev/null +++ b/changelog.d/9176.misc @@ -0,0 +1 @@ +Speed up chain cover calculation when persisting a batch of state events at once. diff --git a/changelog.d/9180.misc b/changelog.d/9180.misc new file mode 100644 index 0000000000..69dd86110d --- /dev/null +++ b/changelog.d/9180.misc @@ -0,0 +1 @@ +Add a `long_description_type` to the package metadata. diff --git a/changelog.d/9181.misc b/changelog.d/9181.misc new file mode 100644 index 0000000000..7820d09cd0 --- /dev/null +++ b/changelog.d/9181.misc @@ -0,0 +1 @@ +Speed up batch insertion when using PostgreSQL. diff --git a/changelog.d/9184.misc b/changelog.d/9184.misc new file mode 100644 index 0000000000..70da3d6cf5 --- /dev/null +++ b/changelog.d/9184.misc @@ -0,0 +1 @@ +Emit an error at startup if different Identity Providers are configured with the same `idp_id`. diff --git a/changelog.d/9188.misc b/changelog.d/9188.misc new file mode 100644 index 0000000000..7820d09cd0 --- /dev/null +++ b/changelog.d/9188.misc @@ -0,0 +1 @@ +Speed up batch insertion when using PostgreSQL. diff --git a/changelog.d/9189.misc b/changelog.d/9189.misc new file mode 100644 index 0000000000..9a5740aac2 --- /dev/null +++ b/changelog.d/9189.misc @@ -0,0 +1 @@ +Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. diff --git a/changelog.d/9190.misc b/changelog.d/9190.misc new file mode 100644 index 0000000000..1b0cc56a92 --- /dev/null +++ b/changelog.d/9190.misc @@ -0,0 +1 @@ +Improve performance of concurrent use of `StreamIDGenerators`. diff --git a/changelog.d/9191.misc b/changelog.d/9191.misc new file mode 100644 index 0000000000..b4bc6be13a --- /dev/null +++ b/changelog.d/9191.misc @@ -0,0 +1 @@ +Add some missing source directories to the automatic linting script. \ No newline at end of file diff --git a/changelog.d/9193.bugfix b/changelog.d/9193.bugfix new file mode 100644 index 0000000000..5233ffc3e7 --- /dev/null +++ b/changelog.d/9193.bugfix @@ -0,0 +1 @@ +Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1. diff --git a/changelog.d/9195.bugfix b/changelog.d/9195.bugfix new file mode 100644 index 0000000000..5233ffc3e7 --- /dev/null +++ b/changelog.d/9195.bugfix @@ -0,0 +1 @@ +Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1. diff --git a/debian/changelog b/debian/changelog index 6b819d201d..1c6308e3a2 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,20 @@ +matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium + + * Remove dependency on `python3-distutils`. + + -- Richard van der Hoff <richard@matrix.org> Fri, 15 Jan 2021 12:44:19 +0000 + +matrix-synapse-py3 (1.25.0) stable; urgency=medium + + [ Dan Callahan ] + * Update dependencies to account for the removal of the transitional + dh-systemd package from Debian Bullseye. + + [ Synapse Packaging team ] + * New synapse release 1.25.0. + + -- Synapse Packaging team <packages@matrix.org> Wed, 13 Jan 2021 10:14:55 +0000 + matrix-synapse-py3 (1.24.0) stable; urgency=medium * New synapse release 1.24.0. diff --git a/debian/control b/debian/control index bae14b41e4..8167a901a4 100644 --- a/debian/control +++ b/debian/control @@ -3,9 +3,11 @@ Section: contrib/python Priority: extra Maintainer: Synapse Packaging team <packages@matrix.org> # keep this list in sync with the build dependencies in docker/Dockerfile-dhvirtualenv. +# TODO: Remove the dependency on dh-systemd after dropping support for Ubuntu xenial +# On all other supported releases, it's merely a transitional package which +# does nothing but depends on debhelper (> 9.20160709) Build-Depends: - debhelper (>= 9), - dh-systemd, + debhelper (>= 9.20160709) | dh-systemd, dh-virtualenv (>= 1.1), libsystemd-dev, libpq-dev, @@ -29,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1) Depends: adduser, debconf, - python3-distutils|libpython3-stdlib (<< 3.6), ${misc:Depends}, ${shlibs:Depends}, ${synapse:pydepends}, diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv index 2b7f01f7f7..e529293803 100644 --- a/docker/Dockerfile-dhvirtualenv +++ b/docker/Dockerfile-dhvirtualenv @@ -50,17 +50,22 @@ FROM ${distro} ARG distro="" ENV distro ${distro} +# Python < 3.7 assumes LANG="C" means ASCII-only and throws on printing unicode +# http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + # Install the build dependencies # # NB: keep this list in sync with the list of build-deps in debian/control # TODO: it would be nice to do that automatically. +# TODO: Remove the dh-systemd stanza after dropping support for Ubuntu xenial +# it's a transitional package on all other, more recent releases RUN apt-get update -qq -o Acquire::Languages=none \ && env DEBIAN_FRONTEND=noninteractive apt-get install \ -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \ build-essential \ debhelper \ devscripts \ - dh-systemd \ libsystemd-dev \ lsb-release \ pkg-config \ @@ -70,7 +75,10 @@ RUN apt-get update -qq -o Acquire::Languages=none \ python3-venv \ sqlite3 \ libpq-dev \ - xmlsec1 + xmlsec1 \ + && ( env DEBIAN_FRONTEND=noninteractive apt-get install \ + -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \ + dh-systemd || true ) COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb / diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml index a808485c12..2ed570a5d1 100644 --- a/docker/conf/homeserver.yaml +++ b/docker/conf/homeserver.yaml @@ -198,12 +198,10 @@ old_signing_keys: {} key_refresh_interval: "1d" # 1 Day. # The trusted servers to download signing keys from. -perspectives: - servers: - "matrix.org": - verify_keys: - "ed25519:auto": - key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" +trusted_key_servers: + - server_name: matrix.org + verify_keys: + "ed25519:auto": "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" password_config: enabled: true diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index dfb8c5d751..90faeaaef0 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -4,6 +4,7 @@ * [Quarantining media by ID](#quarantining-media-by-id) * [Quarantining media in a room](#quarantining-media-in-a-room) * [Quarantining all media of a user](#quarantining-all-media-of-a-user) + * [Protecting media from being quarantined](#protecting-media-from-being-quarantined) - [Delete local media](#delete-local-media) * [Delete a specific local media](#delete-a-specific-local-media) * [Delete local media by date or size](#delete-local-media-by-date-or-size) @@ -123,6 +124,29 @@ The following fields are returned in the JSON response body: * `num_quarantined`: integer - The number of media items successfully quarantined +## Protecting media from being quarantined + +This API protects a single piece of local media from being quarantined using the +above APIs. This is useful for sticker packs and other shared media which you do +not want to get quarantined, especially when +[quarantining media in a room](#quarantining-media-in-a-room). + +Request: + +``` +POST /_synapse/admin/v1/media/protect/<media_id> + +{} +``` + +Where `media_id` is in the form of `abcdefg12345...`. + +Response: + +```json +{} +``` + # Delete local media This API deletes the *local* media from the disk of your own server. This includes any local thumbnails and copies of media downloaded from diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index e4d6f8203b..b3d413cf57 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -98,6 +98,8 @@ Body parameters: - ``deactivated``, optional. If unspecified, deactivation state will be left unchanged on existing accounts and set to ``false`` for new accounts. + A user cannot be erased by deactivating with this API. For details on deactivating users see + `Deactivate Account <#deactivate-account>`_. If the user already exists then optional parameters default to the current value. @@ -248,6 +250,25 @@ server admin: see `README.rst <README.rst>`_. The erase parameter is optional and defaults to ``false``. An empty body may be passed for backwards compatibility. +The following actions are performed when deactivating an user: + +- Try to unpind 3PIDs from the identity server +- Remove all 3PIDs from the homeserver +- Delete all devices and E2EE keys +- Delete all access tokens +- Delete the password hash +- Removal from all rooms the user is a member of +- Remove the user from the user directory +- Reject all pending invites +- Remove all account validity information related to the user + +The following additional actions are performed during deactivation if``erase`` +is set to ``true``: + +- Remove the user's display name +- Remove the user's avatar URL +- Mark the user as erased + Reset password ============== @@ -337,6 +358,10 @@ A response body like the following is returned: "total": 2 } +The server returns the list of rooms of which the user and the server +are member. If the user is local, all the rooms of which the user is +member are returned. + **Parameters** The following parameters should be set in the URL: diff --git a/docs/auth_chain_diff.dot b/docs/auth_chain_diff.dot new file mode 100644 index 0000000000..978d579ada --- /dev/null +++ b/docs/auth_chain_diff.dot @@ -0,0 +1,32 @@ +digraph auth { + nodesep=0.5; + rankdir="RL"; + + C [label="Create (1,1)"]; + + BJ [label="Bob's Join (2,1)", color=red]; + BJ2 [label="Bob's Join (2,2)", color=red]; + BJ2 -> BJ [color=red, dir=none]; + + subgraph cluster_foo { + A1 [label="Alice's invite (4,1)", color=blue]; + A2 [label="Alice's Join (4,2)", color=blue]; + A3 [label="Alice's Join (4,3)", color=blue]; + A3 -> A2 -> A1 [color=blue, dir=none]; + color=none; + } + + PL1 [label="Power Level (3,1)", color=darkgreen]; + PL2 [label="Power Level (3,2)", color=darkgreen]; + PL2 -> PL1 [color=darkgreen, dir=none]; + + {rank = same; C; BJ; PL1; A1;} + + A1 -> C [color=grey]; + A1 -> BJ [color=grey]; + PL1 -> C [color=grey]; + BJ2 -> PL1 [penwidth=2]; + + A3 -> PL2 [penwidth=2]; + A1 -> PL1 -> BJ -> C [penwidth=2]; +} diff --git a/docs/auth_chain_diff.dot.png b/docs/auth_chain_diff.dot.png new file mode 100644 index 0000000000..771c07308f --- /dev/null +++ b/docs/auth_chain_diff.dot.png Binary files differdiff --git a/docs/auth_chain_difference_algorithm.md b/docs/auth_chain_difference_algorithm.md new file mode 100644 index 0000000000..30f72a70da --- /dev/null +++ b/docs/auth_chain_difference_algorithm.md @@ -0,0 +1,108 @@ +# Auth Chain Difference Algorithm + +The auth chain difference algorithm is used by V2 state resolution, where a +naive implementation can be a significant source of CPU and DB usage. + +### Definitions + +A *state set* is a set of state events; e.g. the input of a state resolution +algorithm is a collection of state sets. + +The *auth chain* of a set of events are all the events' auth events and *their* +auth events, recursively (i.e. the events reachable by walking the graph induced +by an event's auth events links). + +The *auth chain difference* of a collection of state sets is the union minus the +intersection of the sets of auth chains corresponding to the state sets, i.e an +event is in the auth chain difference if it is reachable by walking the auth +event graph from at least one of the state sets but not from *all* of the state +sets. + +## Breadth First Walk Algorithm + +A way of calculating the auth chain difference without calculating the full auth +chains for each state set is to do a parallel breadth first walk (ordered by +depth) of each state set's auth chain. By tracking which events are reachable +from each state set we can finish early if every pending event is reachable from +every state set. + +This can work well for state sets that have a small auth chain difference, but +can be very inefficient for larger differences. However, this algorithm is still +used if we don't have a chain cover index for the room (e.g. because we're in +the process of indexing it). + +## Chain Cover Index + +Synapse computes auth chain differences by pre-computing a "chain cover" index +for the auth chain in a room, allowing efficient reachability queries like "is +event A in the auth chain of event B". This is done by assigning every event a +*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links* +between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A` +is in the auth chain of `B`) if and only if either: + +1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s + sequence number; or +2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that + `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`. + +There are actually two potential implementations, one where we store links from +each chain to every other reachable chain (the transitive closure of the links +graph), and one where we remove redundant links (the transitive reduction of the +links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1` +would not be stored. Synapse uses the former implementations so that it doesn't +need to recurse to test reachability between chains. + +### Example + +An example auth graph would look like the following, where chains have been +formed based on type/state_key and are denoted by colour and are labelled with +`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey +are those that would be remove in the second implementation described above). + +![Example](auth_chain_diff.dot.png) + +Note that we don't include all links between events and their auth events, as +most of those links would be redundant. For example, all events point to the +create event, but each chain only needs the one link from it's base to the +create event. + +## Using the Index + +This index can be used to calculate the auth chain difference of the state sets +by looking at the chain ID and sequence numbers reachable from each state set: + +1. For every state set lookup the chain ID/sequence numbers of each state event +2. Use the index to find all chains and the maximum sequence number reachable + from each state set. +3. The auth chain difference is then all events in each chain that have sequence + numbers between the maximum sequence number reachable from *any* state set and + the minimum reachable by *all* state sets (if any). + +Note that steps 2 is effectively calculating the auth chain for each state set +(in terms of chain IDs and sequence numbers), and step 3 is calculating the +difference between the union and intersection of the auth chains. + +### Worked Example + +For example, given the above graph, we can calculate the difference between +state sets consisting of: + +1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and +2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`. + +Using the index we see that the following auth chains are reachable from each +state set: + +1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)` +2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)` + +And so, for each the ranges that are in the auth chain difference: +1. Chain 1: None, (since everything can reach the create event). +2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state + sets and the maximum reachable is `2` (corresponding to Bob's second join). +3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power + level). +4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins). + +So the final result is: Bob's second join `(2,2)`, the second power level +`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`. diff --git a/docs/openid.md b/docs/openid.md index ffa4238fff..b86ae89768 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -42,11 +42,10 @@ as follows: * For other installation mechanisms, see the documentation provided by the maintainer. -To enable the OpenID integration, you should then add an `oidc_config` section -to your configuration file (or uncomment the `enabled: true` line in the -existing section). See [sample_config.yaml](./sample_config.yaml) for some -sample settings, as well as the text below for example configurations for -specific providers. +To enable the OpenID integration, you should then add a section to the `oidc_providers` +setting in your configuration file (or uncomment one of the existing examples). +See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as +the text below for example configurations for specific providers. ## Sample configs @@ -62,20 +61,21 @@ Directory (tenant) ID as it will be used in the Azure links. Edit your Synapse config file and change the `oidc_config` section: ```yaml -oidc_config: - enabled: true - issuer: "https://login.microsoftonline.com/<tenant id>/v2.0" - client_id: "<client id>" - client_secret: "<client secret>" - scopes: ["openid", "profile"] - authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize" - token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token" - userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" - - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username.split('@')[0] }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: microsoft + idp_name: Microsoft + issuer: "https://login.microsoftonline.com/<tenant id>/v2.0" + client_id: "<client id>" + client_secret: "<client secret>" + scopes: ["openid", "profile"] + authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize" + token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token" + userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" + + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username.split('@')[0] }}" + display_name_template: "{{ user.name }}" ``` ### [Dex][dex-idp] @@ -103,17 +103,18 @@ Run with `dex serve examples/config-dev.yaml`. Synapse config: ```yaml -oidc_config: - enabled: true - skip_verification: true # This is needed as Dex is served on an insecure endpoint - issuer: "http://127.0.0.1:5556/dex" - client_id: "synapse" - client_secret: "secret" - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.name }}" - display_name_template: "{{ user.name|capitalize }}" +oidc_providers: + - idp_id: dex + idp_name: "My Dex server" + skip_verification: true # This is needed as Dex is served on an insecure endpoint + issuer: "http://127.0.0.1:5556/dex" + client_id: "synapse" + client_secret: "secret" + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.name }}" + display_name_template: "{{ user.name|capitalize }}" ``` ### [Keycloak][keycloak-idp] @@ -152,16 +153,17 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to 8. Copy Secret ```yaml -oidc_config: - enabled: true - issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" - client_id: "synapse" - client_secret: "copy secret generated from above" - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: keycloak + idp_name: "My KeyCloak server" + issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" + client_id: "synapse" + client_secret: "copy secret generated from above" + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### [Auth0][auth0] @@ -191,16 +193,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: auth0 + idp_name: Auth0 + issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### GitHub @@ -219,21 +222,22 @@ does not return a `sub` property, an alternative `subject_claim` has to be set. Synapse config: ```yaml -oidc_config: - enabled: true - discover: false - issuer: "https://github.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - authorization_endpoint: "https://github.com/login/oauth/authorize" - token_endpoint: "https://github.com/login/oauth/access_token" - userinfo_endpoint: "https://api.github.com/user" - scopes: ["read:user"] - user_mapping_provider: - config: - subject_claim: "id" - localpart_template: "{{ user.login }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: github + idp_name: Github + discover: false + issuer: "https://github.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + authorization_endpoint: "https://github.com/login/oauth/authorize" + token_endpoint: "https://github.com/login/oauth/access_token" + userinfo_endpoint: "https://api.github.com/user" + scopes: ["read:user"] + user_mapping_provider: + config: + subject_claim: "id" + localpart_template: "{{ user.login }}" + display_name_template: "{{ user.name }}" ``` ### [Google][google-idp] @@ -243,16 +247,17 @@ oidc_config: 2. add an "OAuth Client ID" for a Web Application under "Credentials". 3. Copy the Client ID and Client Secret, and add the following to your synapse config: ```yaml - oidc_config: - enabled: true - issuer: "https://accounts.google.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.given_name|lower }}" - display_name_template: "{{ user.name }}" + oidc_providers: + - idp_id: google + idp_name: Google + issuer: "https://accounts.google.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.given_name|lower }}" + display_name_template: "{{ user.name }}" ``` 4. Back in the Google console, add this Authorized redirect URI: `[synapse public baseurl]/_synapse/oidc/callback`. @@ -266,16 +271,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://id.twitch.tv/oauth2/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: twitch + idp_name: Twitch + issuer: "https://id.twitch.tv/oauth2/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### GitLab @@ -287,16 +293,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://gitlab.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - scopes: ["openid", "read_user"] - user_profile_method: "userinfo_endpoint" - user_mapping_provider: - config: - localpart_template: '{{ user.nickname }}' - display_name_template: '{{ user.name }}' +oidc_providers: + - idp_id: gitlab + idp_name: Gitlab + issuer: "https://gitlab.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + scopes: ["openid", "read_user"] + user_profile_method: "userinfo_endpoint" + user_mapping_provider: + config: + localpart_template: '{{ user.nickname }}' + display_name_template: '{{ user.name }}' ``` diff --git a/docs/postgres.md b/docs/postgres.md index c30cc1fd8c..680685d04e 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -18,7 +18,7 @@ connect to a postgres database. virtualenv](../INSTALL.md#installing-from-source), you can install the library with: - ~/synapse/env/bin/pip install matrix-synapse[postgres] + ~/synapse/env/bin/pip install "matrix-synapse[postgres]" (substituting the path to your virtualenv for `~/synapse/env`, if you used a different path). You will require the postgres diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c8ae46d1b3..87bfe22237 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -67,11 +67,16 @@ pid_file: DATADIR/homeserver.pid # #web_client_location: https://riot.example.com/ -# The public-facing base URL that clients use to access this HS -# (not including _matrix/...). This is the same URL a user would -# enter into the 'custom HS URL' field on their client. If you -# use synapse with a reverse proxy, this should be the URL to reach -# synapse via the proxy. +# The public-facing base URL that clients use to access this Homeserver (not +# including _matrix/...). This is the same URL a user might enter into the +# 'Custom Homeserver URL' field on their client. If you use Synapse with a +# reverse proxy, this should be the URL to reach Synapse via the proxy. +# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see +# 'listeners' below). +# +# If this is left unset, it defaults to 'https://<server_name>/'. (Note that +# that will not work unless you configure Synapse or a reverse-proxy to listen +# on port 443.) # #public_baseurl: https://example.com/ @@ -1150,8 +1155,9 @@ account_validity: # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -1242,8 +1248,7 @@ account_validity: # The identity server which we suggest that clients should use when users log # in on this server. # -# (By default, no suggestion is made, so it is left up to the client. -# This setting is ignored unless public_baseurl is also set.) +# (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -1268,8 +1273,6 @@ account_validity: # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # -# If a delegate is specified, the config option public_baseurl must also be filled out. -# account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process @@ -1709,141 +1712,158 @@ saml2_config: #idp_entityid: 'https://our_idp/entityid' -# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. +# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration +# and login. # -# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md -# for some example configurations. +# Options for each entry include: # -oidc_config: - # Uncomment the following to enable authorization against an OpenID Connect - # server. Defaults to false. - # - #enabled: true - - # Uncomment the following to disable use of the OIDC discovery mechanism to - # discover endpoints. Defaults to true. - # - #discover: false - - # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to - # discover the provider's endpoints. - # - # Required if 'enabled' is true. - # - #issuer: "https://accounts.example.com/" - - # oauth2 client id to use. - # - # Required if 'enabled' is true. - # - #client_id: "provided-by-your-issuer" - - # oauth2 client secret to use. - # - # Required if 'enabled' is true. - # - #client_secret: "provided-by-your-issuer" - - # auth method to use when exchanging the token. - # Valid values are 'client_secret_basic' (default), 'client_secret_post' and - # 'none'. - # - #client_auth_method: client_secret_post - - # list of scopes to request. This should normally include the "openid" scope. - # Defaults to ["openid"]. - # - #scopes: ["openid", "profile"] - - # the oauth2 authorization endpoint. Required if provider discovery is disabled. - # - #authorization_endpoint: "https://accounts.example.com/oauth2/auth" - - # the oauth2 token endpoint. Required if provider discovery is disabled. - # - #token_endpoint: "https://accounts.example.com/oauth2/token" - - # the OIDC userinfo endpoint. Required if discovery is disabled and the - # "openid" scope is not requested. - # - #userinfo_endpoint: "https://accounts.example.com/userinfo" - - # URI where to fetch the JWKS. Required if discovery is disabled and the - # "openid" scope is used. - # - #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" - - # Uncomment to skip metadata verification. Defaults to false. - # - # Use this if you are connecting to a provider that is not OpenID Connect - # compliant. - # Avoid this in production. - # - #skip_verification: true - - # Whether to fetch the user profile from the userinfo endpoint. Valid - # values are: "auto" or "userinfo_endpoint". - # - # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included - # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. - # - #user_profile_method: "userinfo_endpoint" - - # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead - # of failing. This could be used if switching from password logins to OIDC. Defaults to false. - # - #allow_existing_users: true - - # An external module can be provided here as a custom solution to mapping - # attributes returned from a OIDC provider onto a matrix user. - # - user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. - # - # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers - # for information on implementing a custom mapping provider. - # - #module: mapping_provider.OidcMappingProvider - - # Custom configuration values for the module. This section will be passed as - # a Python dictionary to the user mapping provider module's `parse_config` - # method. - # - # The examples below are intended for the default provider: they should be - # changed if using a custom provider. - # - config: - # name of the claim containing a unique identifier for the user. - # Defaults to `sub`, which OpenID Connect compliant providers should provide. - # - #subject_claim: "sub" - - # Jinja2 template for the localpart of the MXID. - # - # When rendering, this template is given the following variables: - # * user: The claims returned by the UserInfo Endpoint and/or in the ID - # Token - # - # If this is not set, the user will be prompted to choose their - # own username. - # - #localpart_template: "{{ user.preferred_username }}" - - # Jinja2 template for the display name to set on first login. - # - # If unset, no displayname will be set. - # - #display_name_template: "{{ user.given_name }} {{ user.last_name }}" - - # Jinja2 templates for extra attributes to send back to the client during - # login. - # - # Note that these are non-standard and clients will ignore them without modifications. - # - #extra_attributes: - #birthdate: "{{ user.birthdate }}" - +# idp_id: a unique identifier for this identity provider. Used internally +# by Synapse; should be a single word such as 'github'. +# +# Note that, if this is changed, users authenticating via that provider +# will no longer be recognised as the same user! +# +# idp_name: A user-facing name for this identity provider, which is used to +# offer the user a choice of login mechanisms. +# +# idp_icon: An optional icon for this identity provider, which is presented +# by identity picker pages. If given, must be an MXC URI of the format +# mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI +# is to upload an image to an (unencrypted) room and then copy the "url" +# from the source of the event.) +# +# discover: set to 'false' to disable the use of the OIDC discovery mechanism +# to discover endpoints. Defaults to true. +# +# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery +# is enabled) to discover the provider's endpoints. +# +# client_id: Required. oauth2 client id to use. +# +# client_secret: Required. oauth2 client secret to use. +# +# client_auth_method: auth method to use when exchanging the token. Valid +# values are 'client_secret_basic' (default), 'client_secret_post' and +# 'none'. +# +# scopes: list of scopes to request. This should normally include the "openid" +# scope. Defaults to ["openid"]. +# +# authorization_endpoint: the oauth2 authorization endpoint. Required if +# provider discovery is disabled. +# +# token_endpoint: the oauth2 token endpoint. Required if provider discovery is +# disabled. +# +# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is +# disabled and the 'openid' scope is not requested. +# +# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and +# the 'openid' scope is used. +# +# skip_verification: set to 'true' to skip metadata verification. Use this if +# you are connecting to a provider that is not OpenID Connect compliant. +# Defaults to false. Avoid this in production. +# +# user_profile_method: Whether to fetch the user profile from the userinfo +# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. +# +# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is +# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the +# userinfo endpoint. +# +# allow_existing_users: set to 'true' to allow a user logging in via OIDC to +# match a pre-existing account instead of failing. This could be used if +# switching from password logins to OIDC. Defaults to false. +# +# user_mapping_provider: Configuration for how attributes returned from a OIDC +# provider are mapped onto a matrix user. This setting has the following +# sub-properties: +# +# module: The class name of a custom mapping module. Default is +# 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. +# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers +# for information on implementing a custom mapping provider. +# +# config: Configuration for the mapping provider module. This section will +# be passed as a Python dictionary to the user mapping provider +# module's `parse_config` method. +# +# For the default provider, the following settings are available: +# +# sub: name of the claim containing a unique identifier for the +# user. Defaults to 'sub', which OpenID Connect compliant +# providers should provide. +# +# localpart_template: Jinja2 template for the localpart of the MXID. +# If this is not set, the user will be prompted to choose their +# own username. +# +# display_name_template: Jinja2 template for the display name to set +# on first login. If unset, no displayname will be set. +# +# extra_attributes: a map of Jinja2 templates for extra attributes +# to send back to the client during login. +# Note that these are non-standard and clients will ignore them +# without modifications. +# +# When rendering, the Jinja2 templates are given a 'user' variable, +# which is set to the claims returned by the UserInfo Endpoint and/or +# in the ID Token. +# +# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md +# for information on how to configure these options. +# +# For backwards compatibility, it is also possible to configure a single OIDC +# provider via an 'oidc_config' setting. This is now deprecated and admins are +# advised to migrate to the 'oidc_providers' format. (When doing that migration, +# use 'oidc' for the idp_id to ensure that existing users continue to be +# recognised.) +# +oidc_providers: + # Generic example + # + #- idp_id: my_idp + # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" + # discover: false + # issuer: "https://accounts.example.com/" + # client_id: "provided-by-your-issuer" + # client_secret: "provided-by-your-issuer" + # client_auth_method: client_secret_post + # scopes: ["openid", "profile"] + # authorization_endpoint: "https://accounts.example.com/oauth2/auth" + # token_endpoint: "https://accounts.example.com/oauth2/token" + # userinfo_endpoint: "https://accounts.example.com/userinfo" + # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + # skip_verification: true + + # For use with Keycloak + # + #- idp_id: keycloak + # idp_name: Keycloak + # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" + # client_id: "synapse" + # client_secret: "copy secret generated in Keycloak UI" + # scopes: ["openid", "profile"] + + # For use with Github + # + #- idp_id: github + # idp_name: Github + # discover: false + # issuer: "https://github.com/" + # client_id: "your-client-id" # TO BE FILLED + # client_secret: "your-client-secret" # TO BE FILLED + # authorization_endpoint: "https://github.com/login/oauth/authorize" + # token_endpoint: "https://github.com/login/oauth/access_token" + # userinfo_endpoint: "https://api.github.com/user" + # scopes: ["read:user"] + # user_mapping_provider: + # config: + # subject_claim: "id" + # localpart_template: "{ user.login }" + # display_name_template: "{ user.name }" # Enable Central Authentication Service (CAS) for registration and login. @@ -1893,9 +1913,9 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. # # By default, this list is empty. # @@ -1969,6 +1989,14 @@ sso: # # This template has no additional variables. # + # * HTML page shown after a user-interactive authentication session which + # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. + # + # When rendering, this template is given the following variables: + # * server_name: the homeserver's name. + # * user_id_to_verify: the MXID of the user that we are trying to + # validate. + # # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) # attempts to login: 'sso_account_deactivated.html'. # diff --git a/docs/turn-howto.md b/docs/turn-howto.md index a470c274a5..e8f13ad484 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md @@ -232,6 +232,12 @@ Here are a few things to try: (Understanding the output is beyond the scope of this document!) + * You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/. + Note that this test is not fully reliable yet, so don't be discouraged if + the test fails. + [Here](https://github.com/matrix-org/voip-tester) is the github repo of the + source of the tester, where you can file bug reports. + * There is a WebRTC test tool at https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To use it, you will need a username/password for your TURN server. You can diff --git a/docs/workers.md b/docs/workers.md index 298adf8695..d01683681f 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only be used for demo purposes and any admin considering workers should already be running PostgreSQL. +See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability +for a higher level overview. + ## Main process/worker communication The processes communicate with each other via a Synapse-specific protocol called @@ -56,7 +59,7 @@ The appropriate dependencies must also be installed for Synapse. If using a virtualenv, these can be installed with: ```sh -pip install matrix-synapse[redis] +pip install "matrix-synapse[redis]" ``` Note that these dependencies are included when synapse is installed with `pip @@ -214,6 +217,7 @@ expressions: ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$ ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$ + ^/_matrix/client/(api/v1|r0|unstable)/devices$ ^/_matrix/client/(api/v1|r0|unstable)/keys/query$ ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$ ^/_matrix/client/versions$ diff --git a/mypy.ini b/mypy.ini index b996867121..bd99069c81 100644 --- a/mypy.ini +++ b/mypy.ini @@ -100,6 +100,7 @@ files = synapse/util/async_helpers.py, synapse/util/caches, synapse/util/metrics.py, + synapse/util/stringutils.py, tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index f328ab57d5..fe2965cd36 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -80,7 +80,8 @@ else # then lint everything! if [[ -z ${files+x} ]]; then # Lint all source code files and directories - files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark") + # Note: this list aims the mirror the one in tox.ini + files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite") fi fi diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 22dd169bfb..69bf9110a6 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { "events": ["processed", "outlier", "contains_url"], - "rooms": ["is_public"], + "rooms": ["is_public", "has_auth_chain_index"], "event_edges": ["is_state"], "presence_list": ["accepted"], "presence_stream": ["currently_active"], diff --git a/setup.py b/setup.py index 9730afb41b..ddbe9f511a 100755 --- a/setup.py +++ b/setup.py @@ -121,6 +121,7 @@ setup( include_package_data=True, zip_safe=False, long_description=long_description, + long_description_content_type="text/x-rst", python_requires="~=3.5", classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/synapse/__init__.py b/synapse/__init__.py index 99fb675748..d423856d82 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.25.0rc1" +__version__ = "1.26.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 48c4d7b0be..67ecbd32ff 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.appservice import ApplicationService from synapse.events import EventBase +from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging import opentracing as opentracing from synapse.storage.databases.main.registration import TokenLookupResult @@ -186,8 +187,8 @@ class Auth: AuthError if access is denied for the user in the access token """ try: - ip_addr = self.hs.get_ip_from_request(request) - user_agent = request.get_user_agent("") + ip_addr = request.getClientIP() + user_agent = get_request_user_agent(request) access_token = self.get_access_token_from_request(request) @@ -275,7 +276,7 @@ class Auth: return None, None if app_service.ip_range_whitelist: - ip_address = IPAddress(self.hs.get_ip_from_request(request)) + ip_address = IPAddress(request.getClientIP()) if ip_address not in app_service.ip_range_whitelist: return None, None diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 6379c86dde..e36aeef31f 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,8 +42,6 @@ class ConsentURIBuilder: """ if hs_config.form_secret is None: raise ConfigError("form_secret not set in config") - if hs_config.public_baseurl is None: - raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 37ecdbe3d8..395e202b89 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +20,7 @@ import signal import socket import sys import traceback -from typing import Iterable +from typing import Awaitable, Callable, Iterable from typing_extensions import NoReturn @@ -143,6 +144,45 @@ def quit_with_error(error_string: str) -> NoReturn: sys.exit(1) +def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None: + """Register a callback with the reactor, to be called once it is running + + This can be used to initialise parts of the system which require an asynchronous + setup. + + Any exception raised by the callback will be printed and logged, and the process + will exit. + """ + + async def wrapper(): + try: + await cb(*args, **kwargs) + except Exception: + # previously, we used Failure().printTraceback() here, in the hope that + # would give better tracebacks than traceback.print_exc(). However, that + # doesn't handle chained exceptions (with a __cause__ or __context__) well, + # and I *think* the need for Failure() is reduced now that we mostly use + # async/await. + + # Write the exception to both the logs *and* the unredirected stderr, + # because people tend to get confused if it only goes to one or the other. + # + # One problem with this is that if people are using a logging config that + # logs to the console (as is common eg under docker), they will get two + # copies of the exception. We could maybe try to detect that, but it's + # probably a cost we can bear. + logger.fatal("Error during startup", exc_info=True) + print("Error during startup:", file=sys.__stderr__) + traceback.print_exc(file=sys.__stderr__) + + # it's no use calling sys.exit here, since that just raises a SystemExit + # exception which is then caught by the reactor, and everything carries + # on as normal. + os._exit(1) + + reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) + + def listen_metrics(bind_addresses, port): """ Start Prometheus metrics server. @@ -227,7 +267,7 @@ def refresh_certificate(hs): logger.info("Context factories updated.") -def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): +async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): """ Start a Synapse server or worker. @@ -241,75 +281,67 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): hs: homeserver instance listeners: Listener configuration ('listeners' in homeserver.yaml) """ - try: - # Set up the SIGHUP machinery. - if hasattr(signal, "SIGHUP"): + # Set up the SIGHUP machinery. + if hasattr(signal, "SIGHUP"): + reactor = hs.get_reactor() - reactor = hs.get_reactor() + @wrap_as_background_process("sighup") + def handle_sighup(*args, **kwargs): + # Tell systemd our state, if we're using it. This will silently fail if + # we're not using systemd. + sdnotify(b"RELOADING=1") - @wrap_as_background_process("sighup") - def handle_sighup(*args, **kwargs): - # Tell systemd our state, if we're using it. This will silently fail if - # we're not using systemd. - sdnotify(b"RELOADING=1") + for i, args, kwargs in _sighup_callbacks: + i(*args, **kwargs) - for i, args, kwargs in _sighup_callbacks: - i(*args, **kwargs) + sdnotify(b"READY=1") - sdnotify(b"READY=1") + # We defer running the sighup handlers until next reactor tick. This + # is so that we're in a sane state, e.g. flushing the logs may fail + # if the sighup happens in the middle of writing a log entry. + def run_sighup(*args, **kwargs): + # `callFromThread` should be "signal safe" as well as thread + # safe. + reactor.callFromThread(handle_sighup, *args, **kwargs) - # We defer running the sighup handlers until next reactor tick. This - # is so that we're in a sane state, e.g. flushing the logs may fail - # if the sighup happens in the middle of writing a log entry. - def run_sighup(*args, **kwargs): - # `callFromThread` should be "signal safe" as well as thread - # safe. - reactor.callFromThread(handle_sighup, *args, **kwargs) + signal.signal(signal.SIGHUP, run_sighup) - signal.signal(signal.SIGHUP, run_sighup) + register_sighup(refresh_certificate, hs) - register_sighup(refresh_certificate, hs) + # Load the certificate from disk. + refresh_certificate(hs) - # Load the certificate from disk. - refresh_certificate(hs) + # Start the tracer + synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa + hs + ) - # Start the tracer - synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa - hs - ) + # It is now safe to start your Synapse. + hs.start_listening(listeners) + hs.get_datastore().db_pool.start_profiling() + hs.get_pusherpool().start() + + # Log when we start the shut down process. + hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", logger.info, "Shutting down..." + ) - # It is now safe to start your Synapse. - hs.start_listening(listeners) - hs.get_datastore().db_pool.start_profiling() - hs.get_pusherpool().start() + setup_sentry(hs) + setup_sdnotify(hs) - # Log when we start the shut down process. - hs.get_reactor().addSystemEventTrigger( - "before", "shutdown", logger.info, "Shutting down..." - ) + # If background tasks are running on the main process, start collecting the + # phone home stats. + if hs.config.run_background_tasks: + start_phone_stats_home(hs) - setup_sentry(hs) - setup_sdnotify(hs) - - # If background tasks are running on the main process, start collecting the - # phone home stats. - if hs.config.run_background_tasks: - start_phone_stats_home(hs) - - # We now freeze all allocated objects in the hopes that (almost) - # everything currently allocated are things that will be used for the - # rest of time. Doing so means less work each GC (hopefully). - # - # This only works on Python 3.7 - if sys.version_info >= (3, 7): - gc.collect() - gc.freeze() - except Exception: - traceback.print_exc(file=sys.stderr) - reactor = hs.get_reactor() - if reactor.running: - reactor.stop() - sys.exit(1) + # We now freeze all allocated objects in the hopes that (almost) + # everything currently allocated are things that will be used for the + # rest of time. Doing so means less work each GC (hopefully). + # + # This only works on Python 3.7 + if sys.version_info >= (3, 7): + gc.collect() + gc.freeze() def setup_sentry(hs): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 4428472707..e60988fa4a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set from typing_extensions import ContextManager -from twisted.internet import address, reactor +from twisted.internet import address import synapse import synapse.events @@ -34,6 +34,7 @@ from synapse.api.urls import ( SERVER_KEY_V2_PREFIX, ) from synapse.app import _base +from synapse.app._base import register_start from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging @@ -99,14 +100,28 @@ from synapse.rest.client.v1.profile import ( ) from synapse.rest.client.v1.push_rule import PushRuleRestServlet from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha import groups, sync, user_directory +from synapse.rest.client.v2_alpha import ( + account_data, + groups, + read_marker, + receipts, + room_keys, + sync, + tags, + user_directory, +) from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha.account import ThreepidRestServlet from synapse.rest.client.v2_alpha.account_data import ( AccountDataServlet, RoomAccountDataServlet, ) -from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet +from synapse.rest.client.v2_alpha.devices import DevicesRestServlet +from synapse.rest.client.v2_alpha.keys import ( + KeyChangesServlet, + KeyQueryServlet, + OneTimeKeyServlet, +) from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet @@ -115,6 +130,7 @@ from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore +from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( @@ -446,6 +462,7 @@ class GenericWorkerSlavedStore( UserDirectoryStore, StatsStore, UIAuthWorkerStore, + EndToEndRoomKeyStore, SlavedDeviceInboxStore, SlavedDeviceStore, SlavedReceiptsStore, @@ -502,7 +519,9 @@ class GenericWorkerServer(HomeServer): RegisterRestServlet(self).register(resource) LoginRestServlet(self).register(resource) ThreepidRestServlet(self).register(resource) + DevicesRestServlet(self).register(resource) KeyQueryServlet(self).register(resource) + OneTimeKeyServlet(self).register(resource) KeyChangesServlet(self).register(resource) VoipRestServlet(self).register(resource) PushRuleRestServlet(self).register(resource) @@ -520,6 +539,11 @@ class GenericWorkerServer(HomeServer): room.register_servlets(self, resource, True) room.register_deprecated_servlets(self, resource) InitialSyncRestServlet(self).register(resource) + room_keys.register_servlets(self, resource) + tags.register_servlets(self, resource) + account_data.register_servlets(self, resource) + receipts.register_servlets(self, resource) + read_marker.register_servlets(self, resource) SendToDeviceRestServlet(self).register(resource) @@ -960,9 +984,7 @@ def start(config_options): # streams. Will no-op if no streams can be written to by this worker. hs.get_replication_streamer() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, hs, config.worker_listeners - ) + register_start(_base.start, hs, config.worker_listeners) _base.start_worker_reactor("synapse-generic-worker", config) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b1d9817a6a..57a2f5237c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -15,15 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import logging import os import sys from typing import Iterable, Iterator -from twisted.application import service -from twisted.internet import defer, reactor -from twisted.python.failure import Failure +from twisted.internet import reactor from twisted.web.resource import EncodingResourceWrapper, IResource from twisted.web.server import GzipEncoderFactory from twisted.web.static import File @@ -40,7 +37,7 @@ from synapse.api.urls import ( WEB_CLIENT_PREFIX, ) from synapse.app import _base -from synapse.app._base import listen_ssl, listen_tcp, quit_with_error +from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start from synapse.config._base import ConfigError from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.homeserver import HomeServerConfig @@ -73,7 +70,6 @@ from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.module_loader import load_module -from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.homeserver") @@ -417,40 +413,28 @@ def setup(config_options): _base.refresh_certificate(hs) async def start(): - try: - # Run the ACME provisioning code, if it's enabled. - if hs.config.acme_enabled: - acme = hs.get_acme_handler() - # Start up the webservices which we will respond to ACME - # challenges with, and then provision. - await acme.start_listening() - await do_acme() + # Run the ACME provisioning code, if it's enabled. + if hs.config.acme_enabled: + acme = hs.get_acme_handler() + # Start up the webservices which we will respond to ACME + # challenges with, and then provision. + await acme.start_listening() + await do_acme() - # Check if it needs to be reprovisioned every day. - hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) + # Check if it needs to be reprovisioned every day. + hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) - # Load the OIDC provider metadatas, if OIDC is enabled. - if hs.config.oidc_enabled: - oidc = hs.get_oidc_handler() - # Loading the provider metadata also ensures the provider config is valid. - await oidc.load_metadata() - await oidc.load_jwks() + # Load the OIDC provider metadatas, if OIDC is enabled. + if hs.config.oidc_enabled: + oidc = hs.get_oidc_handler() + # Loading the provider metadata also ensures the provider config is valid. + await oidc.load_metadata() - _base.start(hs, config.listeners) + await _base.start(hs, config.listeners) - hs.get_datastore().db_pool.updates.start_doing_background_updates() - except Exception: - # Print the exception and bail out. - print("Error during startup:", file=sys.stderr) + hs.get_datastore().db_pool.updates.start_doing_background_updates() - # this gives better tracebacks than traceback.print_exc() - Failure().printTraceback(file=sys.stderr) - - if reactor.running: - reactor.stop() - sys.exit(1) - - reactor.callWhenRunning(lambda: defer.ensureDeferred(start())) + register_start(start) return hs @@ -487,25 +471,6 @@ def format_config_error(e: ConfigError) -> Iterator[str]: e = e.__cause__ -class SynapseService(service.Service): - """ - A twisted Service class that will start synapse. Used to run synapse - via twistd and a .tac. - """ - - def __init__(self, config): - self.config = config - - def startService(self): - hs = setup(self.config) - change_resource_limit(hs.config.soft_file_limit) - if hs.config.gc_thresholds: - gc.set_threshold(*hs.config.gc_thresholds) - - def stopService(self): - return self._port.stopListening() - - def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 2931a88207..94144efc87 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -252,11 +252,12 @@ class Config: env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters - env.filters.update({"format_ts": _format_ts_filter}) - if self.public_baseurl: - env.filters.update( - {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)} - ) + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) for filename in filenames: # Load the template diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 2f97e6d258..c7877b4095 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -40,7 +40,7 @@ class CasConfig(Config): self.cas_required_attributes = {} def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """ + return """\ # Enable Central Authentication Service (CAS) for registration and login. # cas_config: diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index d4328c46b9..6a487afd34 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -166,11 +166,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - # public_baseurl is required to build password reset and validation links that - # will be emailed to users - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) @@ -269,9 +264,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( "email.enable_notifs is True but required keys are missing: %s" diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 4e3055282d..bfeceeed18 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2020 Quentin Gliech +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import string +from collections import Counter +from typing import Iterable, Optional, Tuple, Type + +import attr + +from synapse.config._util import validate_config from synapse.python_dependencies import DependencyException, check_requirements +from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module +from synapse.util.stringutils import parse_and_validate_mxc_uri from ._base import Config, ConfigError @@ -25,202 +35,442 @@ class OIDCConfig(Config): section = "oidc" def read_config(self, config, **kwargs): - self.oidc_enabled = False - - oidc_config = config.get("oidc_config") - - if not oidc_config or not oidc_config.get("enabled", False): + self.oidc_providers = tuple(_parse_oidc_provider_configs(config)) + if not self.oidc_providers: return try: check_requirements("oidc") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError(e.message) from e + + # check we don't have any duplicate idp_ids now. (The SSO handler will also + # check for duplicates when the REST listeners get registered, but that happens + # after synapse has forked so doesn't give nice errors.) + c = Counter([i.idp_id for i in self.oidc_providers]) + for idp_id, count in c.items(): + if count > 1: + raise ConfigError( + "Multiple OIDC providers have the idp_id %r." % idp_id + ) public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" - self.oidc_enabled = True - self.oidc_discover = oidc_config.get("discover", True) - self.oidc_issuer = oidc_config["issuer"] - self.oidc_client_id = oidc_config["client_id"] - self.oidc_client_secret = oidc_config["client_secret"] - self.oidc_client_auth_method = oidc_config.get( - "client_auth_method", "client_secret_basic" - ) - self.oidc_scopes = oidc_config.get("scopes", ["openid"]) - self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint") - self.oidc_token_endpoint = oidc_config.get("token_endpoint") - self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") - self.oidc_jwks_uri = oidc_config.get("jwks_uri") - self.oidc_skip_verification = oidc_config.get("skip_verification", False) - self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto") - self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) - - ump_config = oidc_config.get("user_mapping_provider", {}) - ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) - ump_config.setdefault("config", {}) - - ( - self.oidc_user_mapping_provider_class, - self.oidc_user_mapping_provider_config, - ) = load_module(ump_config, ("oidc_config", "user_mapping_provider")) - - # Ensure loaded user mapping module has defined all necessary methods - required_methods = [ - "get_remote_user_id", - "map_user_attributes", - ] - missing_methods = [ - method - for method in required_methods - if not hasattr(self.oidc_user_mapping_provider_class, method) - ] - if missing_methods: - raise ConfigError( - "Class specified by oidc_config." - "user_mapping_provider.module is missing required " - "methods: %s" % (", ".join(missing_methods),) - ) + @property + def oidc_enabled(self) -> bool: + # OIDC is enabled if we have a provider + return bool(self.oidc_providers) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ - # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. + # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration + # and login. + # + # Options for each entry include: + # + # idp_id: a unique identifier for this identity provider. Used internally + # by Synapse; should be a single word such as 'github'. + # + # Note that, if this is changed, users authenticating via that provider + # will no longer be recognised as the same user! + # + # idp_name: A user-facing name for this identity provider, which is used to + # offer the user a choice of login mechanisms. + # + # idp_icon: An optional icon for this identity provider, which is presented + # by identity picker pages. If given, must be an MXC URI of the format + # mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI + # is to upload an image to an (unencrypted) room and then copy the "url" + # from the source of the event.) + # + # discover: set to 'false' to disable the use of the OIDC discovery mechanism + # to discover endpoints. Defaults to true. + # + # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery + # is enabled) to discover the provider's endpoints. + # + # client_id: Required. oauth2 client id to use. + # + # client_secret: Required. oauth2 client secret to use. + # + # client_auth_method: auth method to use when exchanging the token. Valid + # values are 'client_secret_basic' (default), 'client_secret_post' and + # 'none'. + # + # scopes: list of scopes to request. This should normally include the "openid" + # scope. Defaults to ["openid"]. + # + # authorization_endpoint: the oauth2 authorization endpoint. Required if + # provider discovery is disabled. + # + # token_endpoint: the oauth2 token endpoint. Required if provider discovery is + # disabled. + # + # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is + # disabled and the 'openid' scope is not requested. + # + # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and + # the 'openid' scope is used. + # + # skip_verification: set to 'true' to skip metadata verification. Use this if + # you are connecting to a provider that is not OpenID Connect compliant. + # Defaults to false. Avoid this in production. + # + # user_profile_method: Whether to fetch the user profile from the userinfo + # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. + # + # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is + # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the + # userinfo endpoint. + # + # allow_existing_users: set to 'true' to allow a user logging in via OIDC to + # match a pre-existing account instead of failing. This could be used if + # switching from password logins to OIDC. Defaults to false. + # + # user_mapping_provider: Configuration for how attributes returned from a OIDC + # provider are mapped onto a matrix user. This setting has the following + # sub-properties: + # + # module: The class name of a custom mapping module. Default is + # {mapping_provider!r}. + # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers + # for information on implementing a custom mapping provider. + # + # config: Configuration for the mapping provider module. This section will + # be passed as a Python dictionary to the user mapping provider + # module's `parse_config` method. + # + # For the default provider, the following settings are available: + # + # sub: name of the claim containing a unique identifier for the + # user. Defaults to 'sub', which OpenID Connect compliant + # providers should provide. + # + # localpart_template: Jinja2 template for the localpart of the MXID. + # If this is not set, the user will be prompted to choose their + # own username. + # + # display_name_template: Jinja2 template for the display name to set + # on first login. If unset, no displayname will be set. + # + # extra_attributes: a map of Jinja2 templates for extra attributes + # to send back to the client during login. + # Note that these are non-standard and clients will ignore them + # without modifications. + # + # When rendering, the Jinja2 templates are given a 'user' variable, + # which is set to the claims returned by the UserInfo Endpoint and/or + # in the ID Token. # # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md - # for some example configurations. + # for information on how to configure these options. # - oidc_config: - # Uncomment the following to enable authorization against an OpenID Connect - # server. Defaults to false. - # - #enabled: true - - # Uncomment the following to disable use of the OIDC discovery mechanism to - # discover endpoints. Defaults to true. - # - #discover: false - - # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to - # discover the provider's endpoints. - # - # Required if 'enabled' is true. - # - #issuer: "https://accounts.example.com/" - - # oauth2 client id to use. - # - # Required if 'enabled' is true. - # - #client_id: "provided-by-your-issuer" - - # oauth2 client secret to use. + # For backwards compatibility, it is also possible to configure a single OIDC + # provider via an 'oidc_config' setting. This is now deprecated and admins are + # advised to migrate to the 'oidc_providers' format. (When doing that migration, + # use 'oidc' for the idp_id to ensure that existing users continue to be + # recognised.) + # + oidc_providers: + # Generic example # - # Required if 'enabled' is true. + #- idp_id: my_idp + # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" + # discover: false + # issuer: "https://accounts.example.com/" + # client_id: "provided-by-your-issuer" + # client_secret: "provided-by-your-issuer" + # client_auth_method: client_secret_post + # scopes: ["openid", "profile"] + # authorization_endpoint: "https://accounts.example.com/oauth2/auth" + # token_endpoint: "https://accounts.example.com/oauth2/token" + # userinfo_endpoint: "https://accounts.example.com/userinfo" + # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + # skip_verification: true + + # For use with Keycloak # - #client_secret: "provided-by-your-issuer" - - # auth method to use when exchanging the token. - # Valid values are 'client_secret_basic' (default), 'client_secret_post' and - # 'none'. + #- idp_id: keycloak + # idp_name: Keycloak + # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" + # client_id: "synapse" + # client_secret: "copy secret generated in Keycloak UI" + # scopes: ["openid", "profile"] + + # For use with Github # - #client_auth_method: client_secret_post + #- idp_id: github + # idp_name: Github + # discover: false + # issuer: "https://github.com/" + # client_id: "your-client-id" # TO BE FILLED + # client_secret: "your-client-secret" # TO BE FILLED + # authorization_endpoint: "https://github.com/login/oauth/authorize" + # token_endpoint: "https://github.com/login/oauth/access_token" + # userinfo_endpoint: "https://api.github.com/user" + # scopes: ["read:user"] + # user_mapping_provider: + # config: + # subject_claim: "id" + # localpart_template: "{{ user.login }}" + # display_name_template: "{{ user.name }}" + """.format( + mapping_provider=DEFAULT_USER_MAPPING_PROVIDER + ) - # list of scopes to request. This should normally include the "openid" scope. - # Defaults to ["openid"]. - # - #scopes: ["openid", "profile"] - # the oauth2 authorization endpoint. Required if provider discovery is disabled. - # - #authorization_endpoint: "https://accounts.example.com/oauth2/auth" +# jsonschema definition of the configuration settings for an oidc identity provider +OIDC_PROVIDER_CONFIG_SCHEMA = { + "type": "object", + "required": ["issuer", "client_id", "client_secret"], + "properties": { + # TODO: fix the maxLength here depending on what MSC2528 decides + # remember that we prefix the ID given here with `oidc-` + "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, + "idp_name": {"type": "string"}, + "idp_icon": {"type": "string"}, + "discover": {"type": "boolean"}, + "issuer": {"type": "string"}, + "client_id": {"type": "string"}, + "client_secret": {"type": "string"}, + "client_auth_method": { + "type": "string", + # the following list is the same as the keys of + # authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it + # to avoid importing authlib here. + "enum": ["client_secret_basic", "client_secret_post", "none"], + }, + "scopes": {"type": "array", "items": {"type": "string"}}, + "authorization_endpoint": {"type": "string"}, + "token_endpoint": {"type": "string"}, + "userinfo_endpoint": {"type": "string"}, + "jwks_uri": {"type": "string"}, + "skip_verification": {"type": "boolean"}, + "user_profile_method": { + "type": "string", + "enum": ["auto", "userinfo_endpoint"], + }, + "allow_existing_users": {"type": "boolean"}, + "user_mapping_provider": {"type": ["object", "null"]}, + }, +} + +# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name +OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = { + "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}] +} + + +# the `oidc_providers` list can either be None (as it is in the default config), or +# a list of provider configs, each of which requires an explicit ID and name. +OIDC_PROVIDER_LIST_SCHEMA = { + "oneOf": [ + {"type": "null"}, + {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA}, + ] +} + +# the `oidc_config` setting can either be None (which it used to be in the default +# config), or an object. If an object, it is ignored unless it has an "enabled: True" +# property. +# +# It's *possible* to represent this with jsonschema, but the resultant errors aren't +# particularly clear, so we just check for either an object or a null here, and do +# additional checks in the code. +OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]} + +# the top-level schema can contain an "oidc_config" and/or an "oidc_providers". +MAIN_CONFIG_SCHEMA = { + "type": "object", + "properties": { + "oidc_config": OIDC_CONFIG_SCHEMA, + "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA, + }, +} + + +def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]: + """extract and parse the OIDC provider configs from the config dict + + The configuration may contain either a single `oidc_config` object with an + `enabled: True` property, or a list of provider configurations under + `oidc_providers`, *or both*. + + Returns a generator which yields the OidcProviderConfig objects + """ + validate_config(MAIN_CONFIG_SCHEMA, config, ()) + + for i, p in enumerate(config.get("oidc_providers") or []): + yield _parse_oidc_config_dict(p, ("oidc_providers", "<item %i>" % (i,))) + + # for backwards-compatibility, it is also possible to provide a single "oidc_config" + # object with an "enabled: True" property. + oidc_config = config.get("oidc_config") + if oidc_config and oidc_config.get("enabled", False): + # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that + # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA + # above), so now we need to validate it. + validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",)) + yield _parse_oidc_config_dict(oidc_config, ("oidc_config",)) + + +def _parse_oidc_config_dict( + oidc_config: JsonDict, config_path: Tuple[str, ...] +) -> "OidcProviderConfig": + """Take the configuration dict and parse it into an OidcProviderConfig + + Raises: + ConfigError if the configuration is malformed. + """ + ump_config = oidc_config.get("user_mapping_provider", {}) + ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + ump_config.setdefault("config", {}) + + (user_mapping_provider_class, user_mapping_provider_config,) = load_module( + ump_config, config_path + ("user_mapping_provider",) + ) + + # Ensure loaded user mapping module has defined all necessary methods + required_methods = [ + "get_remote_user_id", + "map_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class %s is missing required " + "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),), + config_path + ("user_mapping_provider", "module"), + ) - # the oauth2 token endpoint. Required if provider discovery is disabled. - # - #token_endpoint: "https://accounts.example.com/oauth2/token" + # MSC2858 will apply certain limits in what can be used as an IdP id, so let's + # enforce those limits now. + # TODO: factor out this stuff to a generic function + idp_id = oidc_config.get("idp_id", "oidc") - # the OIDC userinfo endpoint. Required if discovery is disabled and the - # "openid" scope is not requested. - # - #userinfo_endpoint: "https://accounts.example.com/userinfo" + # TODO: update this validity check based on what MSC2858 decides. + valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._") - # URI where to fetch the JWKS. Required if discovery is disabled and the - # "openid" scope is used. - # - #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + if any(c not in valid_idp_chars for c in idp_id): + raise ConfigError( + 'idp_id may only contain a-z, 0-9, "-", ".", "_"', + config_path + ("idp_id",), + ) - # Uncomment to skip metadata verification. Defaults to false. - # - # Use this if you are connecting to a provider that is not OpenID Connect - # compliant. - # Avoid this in production. - # - #skip_verification: true + if idp_id[0] not in string.ascii_lowercase: + raise ConfigError( + "idp_id must start with a-z", config_path + ("idp_id",), + ) - # Whether to fetch the user profile from the userinfo endpoint. Valid - # values are: "auto" or "userinfo_endpoint". - # - # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included - # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. - # - #user_profile_method: "userinfo_endpoint" + # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid + # clashes with other mechs (such as SAML, CAS). + # + # We allow "oidc" as an exception so that people migrating from old-style + # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to + # a new-style "oidc_providers" entry without changing the idp_id for their provider + # (and thereby invalidating their user_external_ids data). - # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead - # of failing. This could be used if switching from password logins to OIDC. Defaults to false. - # - #allow_existing_users: true + if idp_id != "oidc": + idp_id = "oidc-" + idp_id - # An external module can be provided here as a custom solution to mapping - # attributes returned from a OIDC provider onto a matrix user. - # - user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # Default is {mapping_provider!r}. - # - # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers - # for information on implementing a custom mapping provider. - # - #module: mapping_provider.OidcMappingProvider - - # Custom configuration values for the module. This section will be passed as - # a Python dictionary to the user mapping provider module's `parse_config` - # method. - # - # The examples below are intended for the default provider: they should be - # changed if using a custom provider. - # - config: - # name of the claim containing a unique identifier for the user. - # Defaults to `sub`, which OpenID Connect compliant providers should provide. - # - #subject_claim: "sub" - - # Jinja2 template for the localpart of the MXID. - # - # When rendering, this template is given the following variables: - # * user: The claims returned by the UserInfo Endpoint and/or in the ID - # Token - # - # If this is not set, the user will be prompted to choose their - # own username. - # - #localpart_template: "{{{{ user.preferred_username }}}}" - - # Jinja2 template for the display name to set on first login. - # - # If unset, no displayname will be set. - # - #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" - - # Jinja2 templates for extra attributes to send back to the client during - # login. - # - # Note that these are non-standard and clients will ignore them without modifications. - # - #extra_attributes: - #birthdate: "{{{{ user.birthdate }}}}" - """.format( - mapping_provider=DEFAULT_USER_MAPPING_PROVIDER - ) + # MSC2858 also specifies that the idp_icon must be a valid MXC uri + idp_icon = oidc_config.get("idp_icon") + if idp_icon is not None: + try: + parse_and_validate_mxc_uri(idp_icon) + except ValueError as e: + raise ConfigError( + "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) + ) from e + + return OidcProviderConfig( + idp_id=idp_id, + idp_name=oidc_config.get("idp_name", "OIDC"), + idp_icon=idp_icon, + discover=oidc_config.get("discover", True), + issuer=oidc_config["issuer"], + client_id=oidc_config["client_id"], + client_secret=oidc_config["client_secret"], + client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), + scopes=oidc_config.get("scopes", ["openid"]), + authorization_endpoint=oidc_config.get("authorization_endpoint"), + token_endpoint=oidc_config.get("token_endpoint"), + userinfo_endpoint=oidc_config.get("userinfo_endpoint"), + jwks_uri=oidc_config.get("jwks_uri"), + skip_verification=oidc_config.get("skip_verification", False), + user_profile_method=oidc_config.get("user_profile_method", "auto"), + allow_existing_users=oidc_config.get("allow_existing_users", False), + user_mapping_provider_class=user_mapping_provider_class, + user_mapping_provider_config=user_mapping_provider_config, + ) + + +@attr.s(slots=True, frozen=True) +class OidcProviderConfig: + # a unique identifier for this identity provider. Used in the 'user_external_ids' + # table, as well as the query/path parameter used in the login protocol. + idp_id = attr.ib(type=str) + + # user-facing name for this identity provider. + idp_name = attr.ib(type=str) + + # Optional MXC URI for icon for this IdP. + idp_icon = attr.ib(type=Optional[str]) + + # whether the OIDC discovery mechanism is used to discover endpoints + discover = attr.ib(type=bool) + + # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to + # discover the provider's endpoints. + issuer = attr.ib(type=str) + + # oauth2 client id to use + client_id = attr.ib(type=str) + + # oauth2 client secret to use + client_secret = attr.ib(type=str) + + # auth method to use when exchanging the token. + # Valid values are 'client_secret_basic', 'client_secret_post' and + # 'none'. + client_auth_method = attr.ib(type=str) + + # list of scopes to request + scopes = attr.ib(type=Collection[str]) + + # the oauth2 authorization endpoint. Required if discovery is disabled. + authorization_endpoint = attr.ib(type=Optional[str]) + + # the oauth2 token endpoint. Required if discovery is disabled. + token_endpoint = attr.ib(type=Optional[str]) + + # the OIDC userinfo endpoint. Required if discovery is disabled and the + # "openid" scope is not requested. + userinfo_endpoint = attr.ib(type=Optional[str]) + + # URI where to fetch the JWKS. Required if discovery is disabled and the + # "openid" scope is used. + jwks_uri = attr.ib(type=Optional[str]) + + # Whether to skip metadata verification + skip_verification = attr.ib(type=bool) + + # Whether to fetch the user profile from the userinfo endpoint. Valid + # values are: "auto" or "userinfo_endpoint". + user_profile_method = attr.ib(type=str) + + # whether to allow a user logging in via OIDC to match a pre-existing account + # instead of failing + allow_existing_users = attr.ib(type=bool) + + # the class of the user mapping provider + user_mapping_provider_class = attr.ib(type=Type) + + # the config of the user mapping provider + user_mapping_provider_config = attr.ib() diff --git a/synapse/config/registration.py b/synapse/config/registration.py index cc5f75123c..4bfc69cb7a 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -14,14 +14,13 @@ # limitations under the License. import os -from distutils.util import strtobool import pkg_resources from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias, UserID -from synapse.util.stringutils import random_string_with_symbols +from synapse.util.stringutils import random_string_with_symbols, strtobool class AccountValidityConfig(Config): @@ -50,10 +49,6 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 - if self.renew_by_email_enabled: - if "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") - template_dir = config.get("template_dir") if not template_dir: @@ -86,12 +81,12 @@ class RegistrationConfig(Config): section = "registration" def read_config(self, config, **kwargs): - self.enable_registration = bool( - strtobool(str(config.get("enable_registration", False))) + self.enable_registration = strtobool( + str(config.get("enable_registration", False)) ) if "disable_registration" in config: - self.enable_registration = not bool( - strtobool(str(config["disable_registration"])) + self.enable_registration = not strtobool( + str(config["disable_registration"]) ) self.account_validity = AccountValidityConfig( @@ -110,13 +105,6 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if self.account_threepid_delegate_msisdn and not self.public_baseurl: - raise ConfigError( - "The configuration option `public_baseurl` is required if " - "`account_threepid_delegate.msisdn` is set, such that " - "clients know where to submit validation tokens to. Please " - "configure `public_baseurl`." - ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -241,8 +229,9 @@ class RegistrationConfig(Config): # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -333,8 +322,7 @@ class RegistrationConfig(Config): # The identity server which we suggest that clients should use when users log # in on this server. # - # (By default, no suggestion is made, so it is left up to the client. - # This setting is ignored unless public_baseurl is also set.) + # (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -359,8 +347,6 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # - # If a delegate is specified, the config option public_baseurl must also be filled out. - # account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 7b97d4f114..f33dfa0d6a 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -189,8 +189,6 @@ class SAML2Config(Config): import saml2 public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("saml2_config requires a public_baseurl to be set") if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) diff --git a/synapse/config/server.py b/synapse/config/server.py index 7242a4aa8e..47a0370173 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -26,7 +26,7 @@ import yaml from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError @@ -161,7 +161,11 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) - self.public_baseurl = config.get("public_baseurl") + self.public_baseurl = config.get("public_baseurl") or "https://%s/" % ( + self.server_name, + ) + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -317,9 +321,6 @@ class ServerConfig(Config): # Always blacklist 0.0.0.0, :: self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - if self.public_baseurl is not None: - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -740,11 +741,16 @@ class ServerConfig(Config): # #web_client_location: https://riot.example.com/ - # The public-facing base URL that clients use to access this HS - # (not including _matrix/...). This is the same URL a user would - # enter into the 'custom HS URL' field on their client. If you - # use synapse with a reverse proxy, this should be the URL to reach - # synapse via the proxy. + # The public-facing base URL that clients use to access this Homeserver (not + # including _matrix/...). This is the same URL a user might enter into the + # 'Custom Homeserver URL' field on their client. If you use Synapse with a + # reverse proxy, this should be the URL to reach Synapse via the proxy. + # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see + # 'listeners' below). + # + # If this is left unset, it defaults to 'https://<server_name>/'. (Note that + # that will not work unless you configure Synapse or a reverse-proxy to listen + # on port 443.) # #public_baseurl: https://example.com/ diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 1aeb1c5c92..59be825532 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -37,6 +37,7 @@ class SSOConfig(Config): self.sso_error_template, sso_account_deactivated_template, sso_auth_success_template, + self.sso_auth_bad_user_template, ) = self.read_templates( [ "sso_login_idp_picker.html", @@ -45,6 +46,7 @@ class SSOConfig(Config): "sso_error.html", "sso_account_deactivated.html", "sso_auth_success.html", + "sso_auth_bad_user.html", ], template_dir, ) @@ -62,11 +64,8 @@ class SSOConfig(Config): # gracefully to the client). This would make it pointless to ask the user for # confirmation, since the URL the confirmation page would be showing wouldn't be # the client's. - # public_baseurl is an optional setting, so we only add the fallback's URL to the - # list if it's provided (because we can't figure out what that URL is otherwise). - if self.public_baseurl: - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" - self.sso_client_whitelist.append(login_fallback_url) + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): return """\ @@ -84,9 +83,9 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. # # By default, this list is empty. # @@ -160,6 +159,14 @@ class SSOConfig(Config): # # This template has no additional variables. # + # * HTML page shown after a user-interactive authentication session which + # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. + # + # When rendering, this template is given the following variables: + # * server_name: the homeserver's name. + # * user_id_to_verify: the MXID of the user that we are trying to + # validate. + # # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) # attempts to login: 'sso_account_deactivated.html'. # diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 364583f48b..f10e33f7b8 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -56,6 +56,12 @@ class WriterLocations: to_device = attr.ib( default=["master"], type=List[str], converter=_instance_to_list_converter, ) + account_data = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) + receipts = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) class WorkerConfig(Config): @@ -127,7 +133,7 @@ class WorkerConfig(Config): # Check that the configured writers for events and typing also appears in # `instance_map`. - for stream in ("events", "typing", "to_device"): + for stream in ("events", "typing", "to_device", "account_data", "receipts"): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: if instance != "master" and instance not in self.instance_map: @@ -141,6 +147,16 @@ class WorkerConfig(Config): "Must only specify one instance to handle `to_device` messages." ) + if len(self.writers.account_data) != 1: + raise ConfigError( + "Must only specify one instance to handle `account_data` messages." + ) + + if len(self.writers.receipts) != 1: + raise ConfigError( + "Must only specify one instance to handle `receipts` messages." + ) + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) # Whether this worker should run background tasks or not. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 8028663fa8..3ec4120f85 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -17,7 +17,6 @@ import abc import os -from distutils.util import strtobool from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 @@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze +from synapse.util.stringutils import strtobool # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # bugs where we accidentally share e.g. signature dicts. However, converting a @@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze # NOTE: This is overridden by the configuration by the Synapse worker apps, but # for the sake of tests, it is set here while it cannot be configured on the # homeserver object itself. + USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 302b2f69bc..d330ae5dbc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,6 +18,7 @@ import copy import itertools import logging from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -26,7 +27,6 @@ from typing import ( List, Mapping, Optional, - Sequence, Tuple, TypeVar, Union, @@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) @@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError): class FederationClient(FederationBase): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.pdu_destination_tried = {} + self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]] self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -116,33 +119,32 @@ class FederationClient(FederationBase): self.pdu_destination_tried[event_id] = destination_dict @log_function - def make_query( + async def make_query( self, - destination, - query_type, - args, - retry_on_dns_fail=False, - ignore_backoff=False, - ): + destination: str, + query_type: str, + args: dict, + retry_on_dns_fail: bool = False, + ignore_backoff: bool = False, + ) -> JsonDict: """Sends a federation Query to a remote homeserver of the given type and arguments. Args: - destination (str): Domain name of the remote homeserver - query_type (str): Category of the query type; should match the + destination: Domain name of the remote homeserver + query_type: Category of the query type; should match the handler name used in register_query_handler(). - args (dict): Mapping of strings to strings containing the details + args: Mapping of strings to strings containing the details of the query request. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. Returns: - a Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels(query_type).inc() - return self.transport_layer.make_query( + return await self.transport_layer.make_query( destination, query_type, args, @@ -151,42 +153,52 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content, timeout): + async def query_client_keys( + self, destination: str, content: JsonDict, timeout: int + ) -> JsonDict: """Query device keys for a device hosted on a remote server. Args: - destination (str): Domain name of the remote homeserver - content (dict): The query content. + destination: Domain name of the remote homeserver + content: The query content. Returns: - an Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels("client_device_keys").inc() - return self.transport_layer.query_client_keys(destination, content, timeout) + return await self.transport_layer.query_client_keys( + destination, content, timeout + ) @log_function - def query_user_devices(self, destination, user_id, timeout=30000): + async def query_user_devices( + self, destination: str, user_id: str, timeout: int = 30000 + ) -> JsonDict: """Query the device keys for a list of user ids hosted on a remote server. """ sent_queries_counter.labels("user_devices").inc() - return self.transport_layer.query_user_devices(destination, user_id, timeout) + return await self.transport_layer.query_user_devices( + destination, user_id, timeout + ) @log_function - def claim_client_keys(self, destination, content, timeout): + async def claim_client_keys( + self, destination: str, content: JsonDict, timeout: int + ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. Args: - destination (str): Domain name of the remote homeserver - content (dict): The query content. + destination: Domain name of the remote homeserver + content: The query content. Returns: - an Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() - return self.transport_layer.claim_client_keys(destination, content, timeout) + return await self.transport_layer.claim_client_keys( + destination, content, timeout + ) async def backfill( self, dest: str, room_id: str, limit: int, extremities: Iterable[str] @@ -195,10 +207,10 @@ class FederationClient(FederationBase): given destination server. Args: - dest (str): The remote homeserver to ask. - room_id (str): The room_id to backfill. - limit (int): The maximum number of events to return. - extremities (list): our current backwards extremities, to backfill from + dest: The remote homeserver to ask. + room_id: The room_id to backfill. + limit: The maximum number of events to return. + extremities: our current backwards extremities, to backfill from """ logger.debug("backfill extrem=%s", extremities) @@ -370,7 +382,7 @@ class FederationClient(FederationBase): for events that have failed their checks Returns: - Deferred : A list of PDUs that have valid signatures and hashes. + A list of PDUs that have valid signatures and hashes. """ deferreds = self._check_sigs_and_hashes(room_version, pdus) @@ -418,7 +430,9 @@ class FederationClient(FederationBase): else: return [p for p in valid_pdus if p] - async def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth( + self, destination: str, room_id: str, event_id: str + ) -> List[EventBase]: res = await self.transport_layer.get_event_auth(destination, room_id, event_id) room_version = await self.store.get_room_version(room_id) @@ -700,18 +714,16 @@ class FederationClient(FederationBase): return await self._try_destination_list("send_join", destinations, send_request) - async def _do_send_join(self, destination: str, pdu: EventBase): + async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_join_v2( + return await self.transport_layer.send_join_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -769,7 +781,7 @@ class FederationClient(FederationBase): time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_invite_v2( + return await self.transport_layer.send_invite_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -779,7 +791,6 @@ class FederationClient(FederationBase): "invite_room_state": pdu.unsigned.get("invite_room_state", []), }, ) - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -842,18 +853,16 @@ class FederationClient(FederationBase): "send_leave", destinations, send_request ) - async def _do_send_leave(self, destination, pdu): + async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_leave_v2( + return await self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -879,7 +888,7 @@ class FederationClient(FederationBase): # content. return resp[1] - def get_public_rooms( + async def get_public_rooms( self, remote_server: str, limit: Optional[int] = None, @@ -887,7 +896,7 @@ class FederationClient(FederationBase): search_filter: Optional[Dict] = None, include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, - ): + ) -> JsonDict: """Get the list of public rooms from a remote homeserver Args: @@ -901,8 +910,7 @@ class FederationClient(FederationBase): party instance Returns: - Awaitable[Dict[str, Any]]: The response from the remote server, or None if - `remote_server` is the same as the local server_name + The response from the remote server. Raises: HttpResponseException: There was an exception returned from the remote server @@ -910,7 +918,7 @@ class FederationClient(FederationBase): requests over federation """ - return self.transport_layer.get_public_rooms( + return await self.transport_layer.get_public_rooms( remote_server, limit, since_token, @@ -923,7 +931,7 @@ class FederationClient(FederationBase): self, destination: str, room_id: str, - earliest_events_ids: Sequence[str], + earliest_events_ids: Iterable[str], latest_events: Iterable[EventBase], limit: int, min_depth: int, @@ -974,7 +982,9 @@ class FederationClient(FederationBase): return signed_events - async def forward_third_party_invite(self, destinations, room_id, event_dict): + async def forward_third_party_invite( + self, destinations: Iterable[str], room_id: str, event_dict: JsonDict + ) -> None: for destination in destinations: if destination == self.server_name: continue @@ -983,7 +993,7 @@ class FederationClient(FederationBase): await self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) - return None + return except CodeMessageException: raise except Exception as e: @@ -995,7 +1005,7 @@ class FederationClient(FederationBase): async def get_room_complexity( self, destination: str, room_id: str - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """ Fetch the complexity of a remote room from another server. @@ -1008,10 +1018,9 @@ class FederationClient(FederationBase): could not fetch the complexity. """ try: - complexity = await self.transport_layer.get_room_complexity( + return await self.transport_layer.get_room_complexity( destination=destination, room_id=room_id ) - return complexity except CodeMessageException as e: # We didn't manage to get it -- probably a 404. We are okay if other # servers don't give it to us. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e5339aca23..171d25c945 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -49,7 +49,6 @@ from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction -from synapse.http.endpoint import parse_server_name from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_server_name if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cfd094e58f..95c64510a9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -28,7 +28,6 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -45,6 +44,7 @@ from synapse.logging.opentracing import ( ) from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 341135822e..b1a5df9638 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +13,157 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from typing import TYPE_CHECKING, List, Tuple +from synapse.replication.http.account_data import ( + ReplicationAddTagRestServlet, + ReplicationRemoveTagRestServlet, + ReplicationRoomAccountDataRestServlet, + ReplicationUserAccountDataRestServlet, +) from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.app.homeserver import HomeServer +class AccountDataHandler: + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + self._instance_name = hs.get_instance_name() + self._notifier = hs.get_notifier() + + self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) + self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) + self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) + self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) + self._account_data_writers = hs.config.worker.writers.account_data + + async def add_account_data_to_room( + self, user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> int: + """Add some account_data to a room for a user. + + Args: + user_id: The user to add a tag for. + room_id: The room to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + + Returns: + The maximum stream ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_account_data_to_room( + user_id, room_id, account_data_type, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + + return max_stream_id + else: + response = await self._room_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + account_data_type=account_data_type, + content=content, + ) + return response["max_stream_id"] + + async def add_account_data_for_user( + self, user_id: str, account_data_type: str, content: JsonDict + ) -> int: + """Add some account_data to a room for a user. + + Args: + user_id: The user to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + + Returns: + The maximum stream ID. + """ + + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_account_data_for_user( + user_id, account_data_type, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._user_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + account_data_type=account_data_type, + content=content, + ) + return response["max_stream_id"] + + async def add_tag_to_room( + self, user_id: str, room_id: str, tag: str, content: JsonDict + ) -> int: + """Add a tag to a room for a user. + + Args: + user_id: The user to add a tag for. + room_id: The room to add a tag for. + tag: The tag name to add. + content: A json object to associate with the tag. + + Returns: + The next account data ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_tag_to_room( + user_id, room_id, tag, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._add_tag_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + tag=tag, + content=content, + ) + return response["max_stream_id"] + + async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: + """Remove a tag from a room for a user. + + Returns: + The next account data ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_tag_from_room( + user_id, room_id, tag + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._remove_tag_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + tag=tag, + ) + return response["max_stream_id"] + + class AccountDataEventSource: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f4434673dc..0e98db22b3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -49,8 +49,13 @@ from synapse.api.errors import ( UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS +from synapse.handlers._base import BaseHandler +from synapse.handlers.ui_auth import ( + INTERACTIVE_AUTH_CHECKERS, + UIAuthSessionDataConstants, +) from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.http import get_request_user_agent from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread @@ -62,8 +67,6 @@ from synapse.util.async_helpers import maybe_awaitable from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.app.homeserver import HomeServer @@ -260,10 +263,6 @@ class AuthHandler(BaseHandler): # authenticating for an operation to occur on their account. self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template - # The following template is shown after a successful user interactive - # authentication session. It tells the user they can close the window. - self._sso_auth_success_template = hs.config.sso_auth_success_template - # The following template is shown during the SSO authentication process if # the account is deactivated. self._sso_account_deactivated_template = ( @@ -284,7 +283,6 @@ class AuthHandler(BaseHandler): requester: Requester, request: SynapseRequest, request_body: Dict[str, Any], - clientip: str, description: str, ) -> Tuple[dict, Optional[str]]: """ @@ -301,8 +299,6 @@ class AuthHandler(BaseHandler): request_body: The body of the request sent by the client - clientip: The IP address of the client. - description: A human readable string to be displayed to the user that describes the operation happening on their account. @@ -338,10 +334,10 @@ class AuthHandler(BaseHandler): request_body.pop("auth", None) return request_body, None - user_id = requester.user.to_string() + requester_user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) + self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) # build a list of supported flows supported_ui_auth_types = await self._get_available_ui_auth_types( @@ -349,13 +345,16 @@ class AuthHandler(BaseHandler): ) flows = [[login_type] for login_type in supported_ui_auth_types] + def get_new_session_data() -> JsonDict: + return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id} + try: result, params, session_id = await self.check_ui_auth( - flows, request, request_body, clientip, description + flows, request, request_body, description, get_new_session_data, ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action(user_id) + self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) raise # find the completed login type @@ -363,14 +362,14 @@ class AuthHandler(BaseHandler): if login_type not in result: continue - user_id = result[login_type] + validated_user_id = result[login_type] break else: # this can't happen raise Exception("check_auth returned True but no successful login type") # check that the UI auth matched the access token - if user_id != requester.user.to_string(): + if validated_user_id != requester_user_id: raise AuthError(403, "Invalid auth") # Note that the access token has been validated. @@ -402,13 +401,9 @@ class AuthHandler(BaseHandler): # if sso is enabled, allow the user to log in via SSO iff they have a mapping # from sso to mxid. - if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled: - if await self.store.get_external_ids_by_user(user.to_string()): - ui_auth_types.add(LoginType.SSO) - - # Our CAS impl does not (yet) correctly register users in user_external_ids, - # so always offer that if it's available. - if self.hs.config.cas.cas_enabled: + if await self.hs.get_sso_handler().get_identity_providers_for_user( + user.to_string() + ): ui_auth_types.add(LoginType.SSO) return ui_auth_types @@ -426,8 +421,8 @@ class AuthHandler(BaseHandler): flows: List[List[str]], request: SynapseRequest, clientdict: Dict[str, Any], - clientip: str, description: str, + get_new_session_data: Optional[Callable[[], JsonDict]] = None, ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration @@ -448,11 +443,16 @@ class AuthHandler(BaseHandler): clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. - clientip: The IP address of the client. - description: A human readable string to be displayed to the user that describes the operation happening on their account. + get_new_session_data: + an optional callback which will be called when starting a new session. + it should return data to be stored as part of the session. + + The keys of the returned data should be entries in + UIAuthSessionDataConstants. + Returns: A tuple of (creds, params, session_id). @@ -480,10 +480,15 @@ class AuthHandler(BaseHandler): # If there's no session ID, create a new session. if not sid: + new_session_data = get_new_session_data() if get_new_session_data else {} + session = await self.store.create_ui_auth_session( clientdict, uri, method, description ) + for k, v in new_session_data.items(): + await self.set_session_data(session.session_id, k, v) + else: try: session = await self.store.get_ui_auth_session(sid) @@ -539,7 +544,8 @@ class AuthHandler(BaseHandler): # authentication flow. await self.store.set_ui_auth_clientdict(sid, clientdict) - user_agent = request.get_user_agent("") + user_agent = get_request_user_agent(request) + clientip = request.getClientIP() await self.store.add_user_agent_ip_to_ui_auth_session( session.session_id, user_agent, clientip @@ -644,7 +650,8 @@ class AuthHandler(BaseHandler): Args: session_id: The ID of this session as returned from check_auth - key: The key to store the data under + key: The key to store the data under. An entry from + UIAuthSessionDataConstants. value: The data to store """ try: @@ -660,7 +667,8 @@ class AuthHandler(BaseHandler): Args: session_id: The ID of this session as returned from check_auth - key: The key to store the data under + key: The key the data was stored under. An entry from + UIAuthSessionDataConstants. default: Value to return if the key has not been set """ try: @@ -1334,12 +1342,12 @@ class AuthHandler(BaseHandler): else: return False - async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str: """ Get the HTML for the SSO redirect confirmation page. Args: - redirect_url: The URL to redirect to the SSO provider. + request: The incoming HTTP request session_id: The user interactive authentication session ID. Returns: @@ -1349,30 +1357,38 @@ class AuthHandler(BaseHandler): session = await self.store.get_ui_auth_session(session_id) except StoreError: raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - return self._sso_auth_confirm_template.render( - description=session.description, redirect_url=redirect_url, + + user_id_to_verify = await self.get_session_data( + session_id, UIAuthSessionDataConstants.REQUEST_USER_ID + ) # type: str + + idps = await self.hs.get_sso_handler().get_identity_providers_for_user( + user_id_to_verify ) - async def complete_sso_ui_auth( - self, registered_user_id: str, session_id: str, request: Request, - ): - """Having figured out a mxid for this user, complete the HTTP request + if not idps: + # we checked that the user had some remote identities before offering an SSO + # flow, so either it's been deleted or the client has requested SSO despite + # it not being offered. + raise SynapseError(400, "User has no SSO identities") - Args: - registered_user_id: The registered user ID to complete SSO login for. - session_id: The ID of the user-interactive auth session. - request: The request to complete. - """ - # Mark the stage of the authentication as successful. - # Save the user who authenticated with SSO, this will be used to ensure - # that the account be modified is also the person who logged in. - await self.store.mark_ui_auth_stage_complete( - session_id, LoginType.SSO, registered_user_id + # for now, just pick one + idp_id, sso_auth_provider = next(iter(idps.items())) + if len(idps) > 0: + logger.warning( + "User %r has previously logged in with multiple SSO IdPs; arbitrarily " + "picking %r", + user_id_to_verify, + idp_id, + ) + + redirect_url = await sso_auth_provider.handle_redirect_request( + request, None, session_id ) - # Render the HTML and return. - html = self._sso_auth_success_template - respond_with_html(request, 200, html) + return self._sso_auth_confirm_template.render( + description=session.description, redirect_url=redirect_url, + ) async def complete_sso_login( self, @@ -1488,8 +1504,8 @@ class AuthHandler(BaseHandler): @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({param_name: param}) + query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True) + query.append((param_name, param)) url_parts[4] = urllib.parse.urlencode(query) return urllib.parse.urlunparse(url_parts) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index f3430c6713..0f342c607b 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,6 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" + # we do not currently support icons for CAS auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index e808142365..c4a3b26a84 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID, create_requester +from synapse.types import Requester, UserID, create_requester from ._base import BaseHandler @@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler): self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() self._identity_handler = hs.get_identity_handler() + self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() self._server_name = hs.hostname @@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled async def deactivate_account( - self, user_id: str, erase_data: bool, id_server: Optional[str] = None + self, + user_id: str, + erase_data: bool, + requester: Requester, + id_server: Optional[str] = None, + by_admin: bool = False, ) -> bool: """Deactivate a user's account Args: user_id: ID of user to be deactivated erase_data: whether to GDPR-erase the user's data + requester: The user attempting to make this change. id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). + by_admin: Whether this change was made by an administrator. Returns: True if identity server supports removing threepids, otherwise False. @@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler): # Mark the user as erased, if they asked for that if erase_data: + user = UserID.from_string(user_id) + # Remove avatar URL from this user + await self._profile_handler.set_avatar_url(user, requester, "", by_admin) + # Remove displayname from this user + await self._profile_handler.set_displayname(user, requester, "", by_admin) + logger.info("Marking %s as erased", user_id) await self.store.mark_user_erased(user_id) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index fc974a82e8..0c7737e09d 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -163,7 +163,7 @@ class DeviceMessageHandler: await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) # Immediately attempt a resync in the background - run_in_background(self._user_device_resync, sender_user_id) + run_in_background(self._user_device_resync, user_id=sender_user_id) async def send_device_message( self, diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c05036ad1f..f61844d688 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - assert self.hs.config.public_baseurl - # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 6835c6c462..1607e12935 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import inspect import logging -from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar from urllib.parse import urlencode import attr @@ -35,7 +35,7 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from synapse.config import ConfigError -from synapse.handlers._base import BaseHandler +from synapse.config.oidc_config import OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable @@ -71,6 +71,144 @@ JWK = Dict[str, str] JWKS = TypedDict("JWKS", {"keys": List[JWK]}) +class OidcHandler: + """Handles requests related to the OpenID Connect login flow. + """ + + def __init__(self, hs: "HomeServer"): + self._sso_handler = hs.get_sso_handler() + + provider_confs = hs.config.oidc.oidc_providers + # we should not have been instantiated if there is no configured provider. + assert provider_confs + + self._token_generator = OidcSessionTokenGenerator(hs) + self._providers = { + p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs + } # type: Dict[str, OidcProvider] + + async def load_metadata(self) -> None: + """Validate the config and load the metadata from the remote endpoint. + + Called at startup to ensure we have everything we need. + """ + for idp_id, p in self._providers.items(): + try: + await p.load_metadata() + await p.load_jwks() + except Exception as e: + raise Exception( + "Error while initialising OIDC provider %r" % (idp_id,) + ) from e + + async def handle_oidc_callback(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/oidc/callback + + Since we might want to display OIDC-related errors in a user-friendly + way, we don't raise SynapseError from here. Instead, we call + ``self._sso_handler.render_error`` which displays an HTML page for the error. + + Most of the OpenID Connect logic happens here: + + - first, we check if there was any error returned by the provider and + display it + - then we fetch the session cookie, decode and verify it + - the ``state`` query parameter should match with the one stored in the + session cookie + + Once we know the session is legit, we then delegate to the OIDC Provider + implementation, which will exchange the code with the provider and complete the + login/authentication. + + Args: + request: the incoming request from the browser. + """ + + # The provider might redirect with an error. + # In that case, just display it as-is. + if b"error" in request.args: + # error response from the auth server. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 + # https://openid.net/specs/openid-connect-core-1_0.html#AuthError + error = request.args[b"error"][0].decode() + description = request.args.get(b"error_description", [b""])[0].decode() + + # Most of the errors returned by the provider could be due by + # either the provider misbehaving or Synapse being misconfigured. + # The only exception of that is "access_denied", where the user + # probably cancelled the login flow. In other cases, log those errors. + if error != "access_denied": + logger.error("Error from the OIDC provider: %s %s", error, description) + + self._sso_handler.render_error(request, error, description) + return + + # otherwise, it is presumably a successful response. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2 + + # Fetch the session cookie + session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] + if session is None: + logger.info("No session cookie found") + self._sso_handler.render_error( + request, "missing_session", "No session cookie found" + ) + return + + # Remove the cookie. There is a good chance that if the callback failed + # once, it will fail next time and the code will already be exchanged. + # Removing it early avoids spamming the provider with token requests. + request.addCookie( + SESSION_COOKIE_NAME, + b"", + path="/_synapse/oidc", + expires="Thu, Jan 01 1970 00:00:00 UTC", + httpOnly=True, + sameSite="lax", + ) + + # Check for the state query parameter + if b"state" not in request.args: + logger.info("State parameter is missing") + self._sso_handler.render_error( + request, "invalid_request", "State parameter is missing" + ) + return + + state = request.args[b"state"][0].decode() + + # Deserialize the session token and verify it. + try: + session_data = self._token_generator.verify_oidc_session_token( + session, state + ) + except (MacaroonDeserializationException, ValueError) as e: + logger.exception("Invalid session") + self._sso_handler.render_error(request, "invalid_session", str(e)) + return + except MacaroonInvalidSignatureException as e: + logger.exception("Could not verify session") + self._sso_handler.render_error(request, "mismatching_session", str(e)) + return + + oidc_provider = self._providers.get(session_data.idp_id) + if not oidc_provider: + logger.error("OIDC session uses unknown IdP %r", oidc_provider) + self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP") + return + + if b"code" not in request.args: + logger.info("Code parameter is missing") + self._sso_handler.render_error( + request, "invalid_request", "Code parameter is missing" + ) + return + + code = request.args[b"code"][0].decode() + + await oidc_provider.handle_oidc_callback(request, session_data, code) + + class OidcError(Exception): """Used to catch errors when calling the token_endpoint """ @@ -85,44 +223,56 @@ class OidcError(Exception): return self.error -class OidcHandler(BaseHandler): - """Handles requests related to the OpenID Connect login flow. +class OidcProvider: + """Wraps the config for a single OIDC IdentityProvider + + Provides methods for handling redirect requests and callbacks via that particular + IdP. """ - def __init__(self, hs: "HomeServer"): - super().__init__(hs) + def __init__( + self, + hs: "HomeServer", + token_generator: "OidcSessionTokenGenerator", + provider: OidcProviderConfig, + ): + self._store = hs.get_datastore() + + self._token_generator = token_generator + self._callback_url = hs.config.oidc_callback_url # type: str - self._scopes = hs.config.oidc_scopes # type: List[str] - self._user_profile_method = hs.config.oidc_user_profile_method # type: str + + self._scopes = provider.scopes + self._user_profile_method = provider.user_profile_method self._client_auth = ClientAuth( - hs.config.oidc_client_id, - hs.config.oidc_client_secret, - hs.config.oidc_client_auth_method, + provider.client_id, provider.client_secret, provider.client_auth_method, ) # type: ClientAuth - self._client_auth_method = hs.config.oidc_client_auth_method # type: str + self._client_auth_method = provider.client_auth_method self._provider_metadata = OpenIDProviderMetadata( - issuer=hs.config.oidc_issuer, - authorization_endpoint=hs.config.oidc_authorization_endpoint, - token_endpoint=hs.config.oidc_token_endpoint, - userinfo_endpoint=hs.config.oidc_userinfo_endpoint, - jwks_uri=hs.config.oidc_jwks_uri, + issuer=provider.issuer, + authorization_endpoint=provider.authorization_endpoint, + token_endpoint=provider.token_endpoint, + userinfo_endpoint=provider.userinfo_endpoint, + jwks_uri=provider.jwks_uri, ) # type: OpenIDProviderMetadata - self._provider_needs_discovery = hs.config.oidc_discover # type: bool - self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class( - hs.config.oidc_user_mapping_provider_config - ) # type: OidcMappingProvider - self._skip_verification = hs.config.oidc_skip_verification # type: bool - self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool + self._provider_needs_discovery = provider.discover + self._user_mapping_provider = provider.user_mapping_provider_class( + provider.user_mapping_provider_config + ) + self._skip_verification = provider.skip_verification + self._allow_existing_users = provider.allow_existing_users self._http_client = hs.get_proxied_http_client() self._server_name = hs.config.server_name # type: str - self._macaroon_secret_key = hs.config.macaroon_secret_key # identifier for the external_ids table - self.idp_id = "oidc" + self.idp_id = provider.idp_id # user-facing name of this auth provider - self.idp_name = "OIDC" + self.idp_name = provider.idp_name + + # MXC URI for icon for this auth provider + self.idp_icon = provider.idp_icon self._sso_handler = hs.get_sso_handler() @@ -519,11 +669,14 @@ class OidcHandler(BaseHandler): if not client_redirect_url: client_redirect_url = b"" - cookie = self._generate_oidc_session_token( + cookie = self._token_generator.generate_oidc_session_token( state=state, - nonce=nonce, - client_redirect_url=client_redirect_url.decode(), - ui_auth_session_id=ui_auth_session_id, + session_data=OidcSessionData( + idp_id=self.idp_id, + nonce=nonce, + client_redirect_url=client_redirect_url.decode(), + ui_auth_session_id=ui_auth_session_id, + ), ) request.addCookie( SESSION_COOKIE_NAME, @@ -546,22 +699,16 @@ class OidcHandler(BaseHandler): nonce=nonce, ) - async def handle_oidc_callback(self, request: SynapseRequest) -> None: + async def handle_oidc_callback( + self, request: SynapseRequest, session_data: "OidcSessionData", code: str + ) -> None: """Handle an incoming request to /_synapse/oidc/callback - Since we might want to display OIDC-related errors in a user-friendly - way, we don't raise SynapseError from here. Instead, we call - ``self._sso_handler.render_error`` which displays an HTML page for the error. + By this time we have already validated the session on the synapse side, and + now need to do the provider-specific operations. This includes: - Most of the OpenID Connect logic happens here: - - - first, we check if there was any error returned by the provider and - display it - - then we fetch the session cookie, decode and verify it - - the ``state`` query parameter should match with the one stored in the - session cookie - - once we known this session is legit, exchange the code with the - provider using the ``token_endpoint`` (see ``_exchange_code``) + - exchange the code with the provider using the ``token_endpoint`` (see + ``_exchange_code``) - once we have the token, use it to either extract the UserInfo from the ``id_token`` (``_parse_id_token``), or use the ``access_token`` to fetch UserInfo from the ``userinfo_endpoint`` @@ -571,88 +718,12 @@ class OidcHandler(BaseHandler): Args: request: the incoming request from the browser. + session_data: the session data, extracted from our cookie + code: The authorization code we got from the callback. """ - - # The provider might redirect with an error. - # In that case, just display it as-is. - if b"error" in request.args: - # error response from the auth server. see: - # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 - # https://openid.net/specs/openid-connect-core-1_0.html#AuthError - error = request.args[b"error"][0].decode() - description = request.args.get(b"error_description", [b""])[0].decode() - - # Most of the errors returned by the provider could be due by - # either the provider misbehaving or Synapse being misconfigured. - # The only exception of that is "access_denied", where the user - # probably cancelled the login flow. In other cases, log those errors. - if error != "access_denied": - logger.error("Error from the OIDC provider: %s %s", error, description) - - self._sso_handler.render_error(request, error, description) - return - - # otherwise, it is presumably a successful response. see: - # https://tools.ietf.org/html/rfc6749#section-4.1.2 - - # Fetch the session cookie - session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] - if session is None: - logger.info("No session cookie found") - self._sso_handler.render_error( - request, "missing_session", "No session cookie found" - ) - return - - # Remove the cookie. There is a good chance that if the callback failed - # once, it will fail next time and the code will already be exchanged. - # Removing it early avoids spamming the provider with token requests. - request.addCookie( - SESSION_COOKIE_NAME, - b"", - path="/_synapse/oidc", - expires="Thu, Jan 01 1970 00:00:00 UTC", - httpOnly=True, - sameSite="lax", - ) - - # Check for the state query parameter - if b"state" not in request.args: - logger.info("State parameter is missing") - self._sso_handler.render_error( - request, "invalid_request", "State parameter is missing" - ) - return - - state = request.args[b"state"][0].decode() - - # Deserialize the session token and verify it. - try: - ( - nonce, - client_redirect_url, - ui_auth_session_id, - ) = self._verify_oidc_session_token(session, state) - except MacaroonDeserializationException as e: - logger.exception("Invalid session") - self._sso_handler.render_error(request, "invalid_session", str(e)) - return - except MacaroonInvalidSignatureException as e: - logger.exception("Could not verify session") - self._sso_handler.render_error(request, "mismatching_session", str(e)) - return - # Exchange the code with the provider - if b"code" not in request.args: - logger.info("Code parameter is missing") - self._sso_handler.render_error( - request, "invalid_request", "Code parameter is missing" - ) - return - - logger.debug("Exchanging code") - code = request.args[b"code"][0].decode() try: + logger.debug("Exchanging code") token = await self._exchange_code(code) except OidcError as e: logger.exception("Could not exchange code") @@ -674,14 +745,14 @@ class OidcHandler(BaseHandler): else: logger.debug("Extracting userinfo from id_token") try: - userinfo = await self._parse_id_token(token, nonce=nonce) + userinfo = await self._parse_id_token(token, nonce=session_data.nonce) except Exception as e: logger.exception("Invalid id_token") self._sso_handler.render_error(request, "invalid_token", str(e)) return # first check if we're doing a UIA - if ui_auth_session_id: + if session_data.ui_auth_session_id: try: remote_user_id = self._remote_id_from_userinfo(userinfo) except Exception as e: @@ -690,7 +761,7 @@ class OidcHandler(BaseHandler): return return await self._sso_handler.complete_sso_ui_auth_request( - self.idp_id, remote_user_id, ui_auth_session_id, request + self.idp_id, remote_user_id, session_data.ui_auth_session_id, request ) # otherwise, it's a login @@ -698,133 +769,12 @@ class OidcHandler(BaseHandler): # Call the mapper to register/login the user try: await self._complete_oidc_login( - userinfo, token, request, client_redirect_url + userinfo, token, request, session_data.client_redirect_url ) except MappingException as e: logger.exception("Could not map user") self._sso_handler.render_error(request, "mapping_error", str(e)) - def _generate_oidc_session_token( - self, - state: str, - nonce: str, - client_redirect_url: str, - ui_auth_session_id: Optional[str], - duration_in_ms: int = (60 * 60 * 1000), - ) -> str: - """Generates a signed token storing data about an OIDC session. - - When Synapse initiates an authorization flow, it creates a random state - and a random nonce. Those parameters are given to the provider and - should be verified when the client comes back from the provider. - It is also used to store the client_redirect_url, which is used to - complete the SSO login flow. - - Args: - state: The ``state`` parameter passed to the OIDC provider. - nonce: The ``nonce`` parameter passed to the OIDC provider. - client_redirect_url: The URL the client gave when it initiated the - flow. - ui_auth_session_id: The session ID of the ongoing UI Auth (or - None if this is a login). - duration_in_ms: An optional duration for the token in milliseconds. - Defaults to an hour. - - Returns: - A signed macaroon token with the session information. - """ - macaroon = pymacaroons.Macaroon( - location=self._server_name, identifier="key", key=self._macaroon_secret_key, - ) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("type = session") - macaroon.add_first_party_caveat("state = %s" % (state,)) - macaroon.add_first_party_caveat("nonce = %s" % (nonce,)) - macaroon.add_first_party_caveat( - "client_redirect_url = %s" % (client_redirect_url,) - ) - if ui_auth_session_id: - macaroon.add_first_party_caveat( - "ui_auth_session_id = %s" % (ui_auth_session_id,) - ) - now = self.clock.time_msec() - expiry = now + duration_in_ms - macaroon.add_first_party_caveat("time < %d" % (expiry,)) - - return macaroon.serialize() - - def _verify_oidc_session_token( - self, session: bytes, state: str - ) -> Tuple[str, str, Optional[str]]: - """Verifies and extract an OIDC session token. - - This verifies that a given session token was issued by this homeserver - and extract the nonce and client_redirect_url caveats. - - Args: - session: The session token to verify - state: The state the OIDC provider gave back - - Returns: - The nonce, client_redirect_url, and ui_auth_session_id for this session - """ - macaroon = pymacaroons.Macaroon.deserialize(session) - - v = pymacaroons.Verifier() - v.satisfy_exact("gen = 1") - v.satisfy_exact("type = session") - v.satisfy_exact("state = %s" % (state,)) - v.satisfy_general(lambda c: c.startswith("nonce = ")) - v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) - # Sometimes there's a UI auth session ID, it seems to be OK to attempt - # to always satisfy this. - v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) - v.satisfy_general(self._verify_expiry) - - v.verify(macaroon, self._macaroon_secret_key) - - # Extract the `nonce`, `client_redirect_url`, and maybe the - # `ui_auth_session_id` from the token. - nonce = self._get_value_from_macaroon(macaroon, "nonce") - client_redirect_url = self._get_value_from_macaroon( - macaroon, "client_redirect_url" - ) - try: - ui_auth_session_id = self._get_value_from_macaroon( - macaroon, "ui_auth_session_id" - ) # type: Optional[str] - except ValueError: - ui_auth_session_id = None - - return nonce, client_redirect_url, ui_auth_session_id - - def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: - """Extracts a caveat value from a macaroon token. - - Args: - macaroon: the token - key: the key of the caveat to extract - - Returns: - The extracted value - - Raises: - Exception: if the caveat was not in the macaroon - """ - prefix = key + " = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(prefix): - return caveat.caveat_id[len(prefix) :] - raise ValueError("No %s caveat in macaroon" % (key,)) - - def _verify_expiry(self, caveat: str) -> bool: - prefix = "time < " - if not caveat.startswith(prefix): - return False - expiry = int(caveat[len(prefix) :]) - now = self.clock.time_msec() - return now < expiry - async def _complete_oidc_login( self, userinfo: UserInfo, @@ -901,8 +851,8 @@ class OidcHandler(BaseHandler): # and attempt to match it. attributes = await oidc_response_to_user_attributes(failures=0) - user_id = UserID(attributes.localpart, self.server_name).to_string() - users = await self.store.get_users_by_id_case_insensitive(user_id) + user_id = UserID(attributes.localpart, self._server_name).to_string() + users = await self._store.get_users_by_id_case_insensitive(user_id) if users: # If an existing matrix ID is returned, then use it. if len(users) == 1: @@ -954,6 +904,157 @@ class OidcHandler(BaseHandler): return str(remote_user_id) +class OidcSessionTokenGenerator: + """Methods for generating and checking OIDC Session cookies.""" + + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() + self._server_name = hs.hostname + self._macaroon_secret_key = hs.config.key.macaroon_secret_key + + def generate_oidc_session_token( + self, + state: str, + session_data: "OidcSessionData", + duration_in_ms: int = (60 * 60 * 1000), + ) -> str: + """Generates a signed token storing data about an OIDC session. + + When Synapse initiates an authorization flow, it creates a random state + and a random nonce. Those parameters are given to the provider and + should be verified when the client comes back from the provider. + It is also used to store the client_redirect_url, which is used to + complete the SSO login flow. + + Args: + state: The ``state`` parameter passed to the OIDC provider. + session_data: data to include in the session token. + duration_in_ms: An optional duration for the token in milliseconds. + Defaults to an hour. + + Returns: + A signed macaroon token with the session information. + """ + macaroon = pymacaroons.Macaroon( + location=self._server_name, identifier="key", key=self._macaroon_secret_key, + ) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = session") + macaroon.add_first_party_caveat("state = %s" % (state,)) + macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,)) + macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,)) + macaroon.add_first_party_caveat( + "client_redirect_url = %s" % (session_data.client_redirect_url,) + ) + if session_data.ui_auth_session_id: + macaroon.add_first_party_caveat( + "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,) + ) + now = self._clock.time_msec() + expiry = now + duration_in_ms + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + + return macaroon.serialize() + + def verify_oidc_session_token( + self, session: bytes, state: str + ) -> "OidcSessionData": + """Verifies and extract an OIDC session token. + + This verifies that a given session token was issued by this homeserver + and extract the nonce and client_redirect_url caveats. + + Args: + session: The session token to verify + state: The state the OIDC provider gave back + + Returns: + The data extracted from the session cookie + + Raises: + ValueError if an expected caveat is missing from the macaroon. + """ + macaroon = pymacaroons.Macaroon.deserialize(session) + + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = session") + v.satisfy_exact("state = %s" % (state,)) + v.satisfy_general(lambda c: c.startswith("nonce = ")) + v.satisfy_general(lambda c: c.startswith("idp_id = ")) + v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) + # Sometimes there's a UI auth session ID, it seems to be OK to attempt + # to always satisfy this. + v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) + v.satisfy_general(self._verify_expiry) + + v.verify(macaroon, self._macaroon_secret_key) + + # Extract the session data from the token. + nonce = self._get_value_from_macaroon(macaroon, "nonce") + idp_id = self._get_value_from_macaroon(macaroon, "idp_id") + client_redirect_url = self._get_value_from_macaroon( + macaroon, "client_redirect_url" + ) + try: + ui_auth_session_id = self._get_value_from_macaroon( + macaroon, "ui_auth_session_id" + ) # type: Optional[str] + except ValueError: + ui_auth_session_id = None + + return OidcSessionData( + nonce=nonce, + idp_id=idp_id, + client_redirect_url=client_redirect_url, + ui_auth_session_id=ui_auth_session_id, + ) + + def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: + """Extracts a caveat value from a macaroon token. + + Args: + macaroon: the token + key: the key of the caveat to extract + + Returns: + The extracted value + + Raises: + ValueError: if the caveat was not in the macaroon + """ + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise ValueError("No %s caveat in macaroon" % (key,)) + + def _verify_expiry(self, caveat: str) -> bool: + prefix = "time < " + if not caveat.startswith(prefix): + return False + expiry = int(caveat[len(prefix) :]) + now = self._clock.time_msec() + return now < expiry + + +@attr.s(frozen=True, slots=True) +class OidcSessionData: + """The attributes which are stored in a OIDC session cookie""" + + # the Identity Provider being used + idp_id = attr.ib(type=str) + + # The `nonce` parameter passed to the OIDC provider. + nonce = attr.ib(type=str) + + # The URL the client gave when it initiated the flow. ("" if this is a UI Auth) + client_redirect_url = attr.ib(type=str) + + # The session ID of the ongoing UI Auth (None if this is a login) + ui_auth_session_id = attr.ib(type=Optional[str], default=None) + + UserAttributeDict = TypedDict( "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 36f9ee4b71..c02b951031 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler): 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) + avatar_url_to_set = new_avatar_url # type: Optional[str] + if new_avatar_url == "": + avatar_url_to_set = None + # Same like set_displayname if by_admin: requester = create_requester( target_user, authenticated_entity=requester.authenticated_entity ) - await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + await self.store.set_profile_avatar_url( + target_user.localpart, avatar_url_to_set + ) if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(target_user.localpart) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index a7550806e6..6bb2fd936b 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler): super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() + self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") - self.notifier = hs.get_notifier() async def received_client_read_marker( self, room_id: str, user_id: str, event_id: str @@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler): if should_update: content = {"event_id": event_id} - max_id = await self.store.add_account_data_to_room( + await self.account_data_handler.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a9abdf42e0..cc21fc2284 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler): self.server_name = hs.config.server_name self.store = hs.get_datastore() self.hs = hs - self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler( - "m.receipt", self._received_remote_receipt - ) + + # We only need to poke the federation sender explicitly if its on the + # same instance. Other federation sender instances will get notified by + # `synapse.app.generic_worker.FederationSenderHandler` when it sees it + # in the receipts stream. + self.federation_sender = None + if hs.should_send_federation(): + self.federation_sender = hs.get_federation_sender() + + # If we can handle the receipt EDUs we do so, otherwise we route them + # to the appropriate worker. + if hs.get_instance_name() in hs.config.worker.writers.receipts: + hs.get_federation_registry().register_edu_handler( + "m.receipt", self._received_remote_receipt + ) + else: + hs.get_federation_registry().register_instances_for_edu( + "m.receipt", hs.config.worker.writers.receipts, + ) + self.clock = self.hs.get_clock() self.state = hs.get_state_handler() @@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler): if not is_new: return - await self.federation.send_read_receipt(receipt) + if self.federation_sender: + await self.federation_sender.send_read_receipt(receipt) class ReceiptEventSource: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3bece6d668..ee27d99135 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,7 +38,6 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents -from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -55,6 +54,7 @@ from synapse.types import ( from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client from ._base import BaseHandler diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cb5a29bc7e..e001e418f9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.registration_handler = hs.get_registration_handler() self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.account_data_handler = hs.get_account_data_handler() self.member_linearizer = Linearizer(name="member") @@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): direct_rooms[key].append(new_room_id) # Save back to user's m.direct account data - await self.store.add_account_data_for_user( + await self.account_data_handler.add_account_data_for_user( user_id, AccountDataTypes.DIRECT, direct_rooms ) break @@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) + await self.account_data_handler.add_tag_to_room( + user_id, new_room_id, tag, tag_content + ) async def update_membership( self, diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a8376543c9..38461cf79d 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,6 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" + # we do not currently support icons for SAML auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 2da1ea2223..d493327a10 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -22,7 +22,10 @@ from typing_extensions import NoReturn, Protocol from twisted.web.http import Request +from synapse.api.constants import LoginType from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.handlers.ui_auth import UIAuthSessionDataConstants +from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters @@ -72,6 +75,11 @@ class SsoIdentityProvider(Protocol): def idp_name(self) -> str: """User-facing name for this provider""" + @property + def idp_icon(self) -> Optional[str]: + """Optional MXC URI for user-facing icon""" + return None + @abc.abstractmethod async def handle_redirect_request( self, @@ -145,8 +153,13 @@ class SsoHandler: self._store = hs.get_datastore() self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() - self._error_template = hs.config.sso_error_template self._auth_handler = hs.get_auth_handler() + self._error_template = hs.config.sso_error_template + self._bad_user_template = hs.config.sso_auth_bad_user_template + + # The following template is shown after a successful user interactive + # authentication session. It tells the user they can close the window. + self._sso_auth_success_template = hs.config.sso_auth_success_template # a lock on the mappings self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) @@ -166,6 +179,37 @@ class SsoHandler: """Get the configured identity providers""" return self._identity_providers + async def get_identity_providers_for_user( + self, user_id: str + ) -> Mapping[str, SsoIdentityProvider]: + """Get the SsoIdentityProviders which a user has used + + Given a user id, get the identity providers that that user has used to log in + with in the past (and thus could use to re-identify themselves for UI Auth). + + Args: + user_id: MXID of user to look up + + Raises: + a map of idp_id to SsoIdentityProvider + """ + external_ids = await self._store.get_external_ids_by_user(user_id) + + valid_idps = {} + for idp_id, _ in external_ids: + idp = self._identity_providers.get(idp_id) + if not idp: + logger.warning( + "User %r has an SSO mapping for IdP %r, but this is no longer " + "configured.", + user_id, + idp_id, + ) + else: + valid_idps[idp_id] = idp + + return valid_idps + def render_error( self, request: Request, @@ -362,7 +406,7 @@ class SsoHandler: attributes, auth_provider_id, remote_user_id, - request.get_user_agent(""), + get_request_user_agent(request), request.getClientIP(), ) @@ -545,19 +589,45 @@ class SsoHandler: auth_provider_id, remote_user_id, ) + user_id_to_verify = await self._auth_handler.get_session_data( + ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID + ) # type: str + if not user_id: logger.warning( "Remote user %s/%s has not previously logged in here: UIA will fail", auth_provider_id, remote_user_id, ) - # Let the UIA flow handle this the same as if they presented creds for a - # different user. - user_id = "" + elif user_id != user_id_to_verify: + logger.warning( + "Remote user %s/%s mapped onto incorrect user %s: UIA will fail", + auth_provider_id, + remote_user_id, + user_id, + ) + else: + # success! + # Mark the stage of the authentication as successful. + await self._store.mark_ui_auth_stage_complete( + ui_auth_session_id, LoginType.SSO, user_id + ) + + # Render the HTML confirmation page and return. + html = self._sso_auth_success_template + respond_with_html(request, 200, html) + return + + # the user_id didn't match: mark the stage of the authentication as unsuccessful + await self._store.mark_ui_auth_stage_complete( + ui_auth_session_id, LoginType.SSO, "" + ) - await self._auth_handler.complete_sso_ui_auth( - user_id, ui_auth_session_id, request + # render an error page. + html = self._bad_user_template.render( + server_name=self._server_name, user_id_to_verify=user_id_to_verify, ) + respond_with_html(request, 200, html) async def check_username_availability( self, localpart: str, session_id: str, @@ -628,7 +698,7 @@ class SsoHandler: attributes, session.auth_provider_id, session.remote_user_id, - request.get_user_agent(""), + get_request_user_agent(request), request.getClientIP(), ) diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py index 824f37f8f8..a68d5e790e 100644 --- a/synapse/handlers/ui_auth/__init__.py +++ b/synapse/handlers/ui_auth/__init__.py @@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here. """ from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401 + + +class UIAuthSessionDataConstants: + """Constants for use with AuthHandler.set_session_data""" + + # used during registration and password reset to store a hashed copy of the + # password, so that the client does not need to submit it each time. + PASSWORD_HASH = "password_hash" + + # used during registration to store the mxid of the registered user + REGISTERED_USER_ID = "registered_user_id" + + # used by validate_user_via_ui_auth to store the mxid of the user we are validating + # for. + REQUEST_USER_ID = "request_user_id" diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 59b01b812c..4bc3cb53f0 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -17,6 +17,7 @@ import re from twisted.internet import task from twisted.web.client import FileBodyProducer +from twisted.web.iweb import IRequest from synapse.api.errors import SynapseError @@ -50,3 +51,17 @@ class QuieterFileBodyProducer(FileBodyProducer): FileBodyProducer.stopProducing(self) except task.TaskStopped: pass + + +def get_request_user_agent(request: IRequest, default: str = "") -> str: + """Return the last User-Agent header, or the given default. + """ + # There could be raw utf-8 bytes in the User-Agent header. + + # N.B. if you don't do this, the logger explodes cryptically + # with maximum recursion trying to log errors about + # the charset problem. + # c.f. https://github.com/matrix-org/synapse/issues/3471 + + h = request.getHeader(b"User-Agent") + return h.decode("ascii", "replace") if h else default diff --git a/synapse/http/client.py b/synapse/http/client.py index 29f40ddf5f..37ccf5ab98 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -32,7 +32,7 @@ from typing import ( import treq from canonicaljson import encode_canonical_json -from netaddr import IPAddress, IPSet +from netaddr import AddrFormatError, IPAddress, IPSet from prometheus_client import Counter from zope.interface import implementer, provider @@ -261,16 +261,16 @@ class BlacklistingAgentWrapper(Agent): try: ip_address = IPAddress(h.hostname) - + except AddrFormatError: + # Not an IP + pass + else: if check_against_blacklist( ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info("Blocking access to %s due to blacklist" % (ip_address,)) e = SynapseError(403, "IP address blocked by IP blacklist entry") return defer.fail(Failure(e)) - except Exception: - # Not an IP - pass return self._agent.request( method, uri, headers=headers, bodyProducer=bodyProducer @@ -341,6 +341,7 @@ class SimpleHttpClient: self.agent = ProxyAgent( self.reactor, + hs.get_reactor(), connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, @@ -723,7 +724,7 @@ class SimpleHttpClient: read_body_with_max_size(response, output_stream, max_size) ) except BodyExceededMaxSize: - SynapseError( + raise SynapseError( 502, "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, @@ -765,14 +766,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): self.max_size = max_size def dataReceived(self, data: bytes) -> None: + # If the deferred was called, bail early. + if self.deferred.called: + return + self.stream.write(data) self.length += len(data) + # The first time the maximum size is exceeded, error and cancel the + # connection. dataReceived might be called again if data was received + # in the meantime. if self.max_size is not None and self.length >= self.max_size: self.deferred.errback(BodyExceededMaxSize()) - self.deferred = defer.Deferred() self.transport.loseConnection() def connectionLost(self, reason: Failure) -> None: + # If the maximum size was already exceeded, there's nothing to do. + if self.deferred.called: + return + if reason.check(ResponseDone): self.deferred.callback(self.length) elif reason.check(PotentialDataLoss): diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py deleted file mode 100644 index 92a5b606c8..0000000000 --- a/synapse/http/endpoint.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import re - -logger = logging.getLogger(__name__) - - -def parse_server_name(server_name): - """Split a server name into host/port parts. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - try: - if server_name[-1] == "]": - # ipv6 literal, hopefully - return server_name, None - - domain_port = server_name.rsplit(":", 1) - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None - return domain, port - except Exception: - raise ValueError("Invalid server name '%s'" % server_name) - - -VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") - - -def parse_and_validate_server_name(server_name): - """Split a server name into host/port parts and do some basic validation. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - host, port = parse_server_name(server_name) - - # these tests don't need to be bulletproof as we'll find out soon enough - # if somebody is giving us invalid data. What we *do* need is to be sure - # that nobody is sneaking IP literals in that look like hostnames, etc. - - # look for ipv6 literals - if host[0] == "[": - if host[-1] != "]": - raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) - return host, port - - # otherwise it should only be alphanumerics. - if not VALID_HOST_REGEX.match(host): - raise ValueError( - "Server name '%s' contains invalid characters" % (server_name,) - ) - - return host, port diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 3b756a7dc2..4c06a117d3 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -102,7 +102,6 @@ class MatrixFederationAgent: pool=self._pool, contextFactory=tls_client_options_factory, ), - self._reactor, ip_blacklist=ip_blacklist, ), user_agent=self.user_agent, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index b261e078c4..19293bf673 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -174,6 +174,16 @@ async def _handle_json_response( d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = await make_deferred_yieldable(d) + except ValueError as e: + # The JSON content was invalid. + logger.warning( + "{%s} [%s] Failed to parse JSON response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=False) from e except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", @@ -986,7 +996,7 @@ class MatrixFederationHttpClient: logger.warning( "{%s} [%s] %s", request.txn_id, request.destination, msg, ) - SynapseError(502, msg, Codes.TOO_LARGE) + raise SynapseError(502, msg, Codes.TOO_LARGE) except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index e32d3f43e0..b730d2c634 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -39,6 +39,10 @@ class ProxyAgent(_AgentBase): reactor: twisted reactor to place outgoing connections. + proxy_reactor: twisted reactor to use for connections to the proxy server + reactor might have some blacklisting applied (i.e. for DNS queries), + but we need unblocked access to the proxy. + contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the verification parameters of OpenSSL. The default is to use a `BrowserLikePolicyForHTTPS`, so unless you have special @@ -59,6 +63,7 @@ class ProxyAgent(_AgentBase): def __init__( self, reactor, + proxy_reactor=None, contextFactory=BrowserLikePolicyForHTTPS(), connectTimeout=None, bindAddress=None, @@ -68,6 +73,11 @@ class ProxyAgent(_AgentBase): ): _AgentBase.__init__(self, reactor, pool) + if proxy_reactor is None: + self.proxy_reactor = reactor + else: + self.proxy_reactor = proxy_reactor + self._endpoint_kwargs = {} if connectTimeout is not None: self._endpoint_kwargs["timeout"] = connectTimeout @@ -75,11 +85,11 @@ class ProxyAgent(_AgentBase): self._endpoint_kwargs["bindAddress"] = bindAddress self.http_proxy_endpoint = _http_proxy_endpoint( - http_proxy, reactor, **self._endpoint_kwargs + http_proxy, self.proxy_reactor, **self._endpoint_kwargs ) self.https_proxy_endpoint = _http_proxy_endpoint( - https_proxy, reactor, **self._endpoint_kwargs + https_proxy, self.proxy_reactor, **self._endpoint_kwargs ) self._policy_for_https = contextFactory @@ -137,7 +147,7 @@ class ProxyAgent(_AgentBase): request_path = uri elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: endpoint = HTTPConnectProxyEndpoint( - self._reactor, + self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, diff --git a/synapse/http/site.py b/synapse/http/site.py index 5a5790831b..12ec3f851f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -20,7 +20,7 @@ from twisted.python.failure import Failure from twisted.web.server import Request, Site from synapse.config.server import ListenerConfig -from synapse.http import redact_uri +from synapse.http import get_request_user_agent, redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.types import Requester @@ -113,15 +113,6 @@ class SynapseRequest(Request): method = self.method.decode("ascii") return method - def get_user_agent(self, default: str) -> str: - """Return the last User-Agent header, or the given default. - """ - user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] - if user_agent is None: - return default - - return user_agent.decode("ascii", "replace") - def render(self, resrc): # this is called once a Resource has been found to serve the request; in our # case the Resource in question will normally be a JsonResource. @@ -292,12 +283,7 @@ class SynapseRequest(Request): # and can see that we're doing something wrong. authenticated_entity = repr(self.requester) # type: ignore[unreachable] - # ...or could be raw utf-8 bytes in the User-Agent header. - # N.B. if you don't do this, the logger explodes cryptically - # with maximum recursion trying to log errors about - # the charset problem. - # c.f. https://github.com/matrix-org/synapse/issues/3471 - user_agent = self.get_user_agent("-") + user_agent = get_request_user_agent(self, "-") code = str(self.code) if not self.finished: diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index a84a064c8d..dd527e807f 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -15,6 +15,7 @@ from synapse.http.server import JsonResource from synapse.replication.http import ( + account_data, devices, federation, login, @@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource): presence.register_servlets(hs, self) membership.register_servlets(hs, self) streams.register_servlets(hs, self) + account_data.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 1492ac922c..288727a566 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -177,7 +177,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): @trace(opname="outgoing_replication_request") @outgoing_gauge.track_inprogress() - async def send_request(instance_name="master", **kwargs): + async def send_request(*, instance_name="master", **kwargs): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") if instance_name == "master": diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py new file mode 100644 index 0000000000..52d32528ee --- /dev/null +++ b/synapse/replication/http/account_data.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): + """Add user account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_user_account_data/:user_id/:type + + { + "content": { ... }, + } + + """ + + NAME = "add_user_account_data" + PATH_ARGS = ("user_id", "account_data_type") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, account_data_type, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, account_data_type): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_account_data_for_user( + user_id, account_data_type, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): + """Add room account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type + + { + "content": { ... }, + } + + """ + + NAME = "add_room_account_data" + PATH_ARGS = ("user_id", "room_id", "account_data_type") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, account_data_type, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, room_id, account_data_type): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_account_data_to_room( + user_id, room_id, account_data_type, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationAddTagRestServlet(ReplicationEndpoint): + """Add tag on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_tag/:user_id/:room_id/:tag + + { + "content": { ... }, + } + + """ + + NAME = "add_tag" + PATH_ARGS = ("user_id", "room_id", "tag") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, tag, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, room_id, tag): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_tag_to_room( + user_id, room_id, tag, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationRemoveTagRestServlet(ReplicationEndpoint): + """Remove tag on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag + + {} + + """ + + NAME = "remove_tag" + PATH_ARGS = ( + "user_id", + "room_id", + "tag", + ) + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, tag): + + return {} + + async def _handle_request(self, request, user_id, room_id, tag): + max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,) + + return 200, {"max_stream_id": max_stream_id} + + +def register_servlets(hs, http_server): + ReplicationUserAccountDataRestServlet(hs).register(http_server) + ReplicationRoomAccountDataRestServlet(hs).register(http_server) + ReplicationAddTagRestServlet(hs).register(http_server) + ReplicationRemoveTagRestServlet(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index d0089fe06c..693c9ab901 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): database, stream_name="caches", instance_name=hs.get_instance_name(), - table="cache_invalidation_stream_by_instance", - instance_column="instance_name", - id_column="stream_id", + tables=[ + ( + "cache_invalidation_stream_by_instance", + "instance_name", + "stream_id", + ) + ], sequence_name="cache_invalidation_stream_seq", writers=[], ) # type: Optional[MultiWriterIdGenerator] diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 4268565fc8..21afe5f155 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -15,47 +15,9 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "account_data", - "stream_id", - extra_tables=[ - ("room_account_data", "stream_id"), - ("room_tags_revisions", "stream_id"), - ], - ) - - super().__init__(database, db_conn, hs) - - def get_max_account_data_stream_id(self): - return self._account_data_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - for row in rows: - if not row.room_id: - self.get_global_account_data_by_type_for_user.invalidate( - (row.data_type, row.user_id) - ) - self.get_account_data_for_user.invalidate((row.user_id,)) - self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) - self.get_account_data_for_room_and_type.invalidate( - (row.user_id, row.room_id, row.data_type) - ) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 6195917376..3dfdd9961d 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -14,43 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.tcp.streams import ReceiptsStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) - - super().__init__(database, db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() - - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): - self.get_receipts_for_user.invalidate((user_id, receipt_type)) - self._get_linearized_receipts_for_room.invalidate_many((room_id,)) - self.get_last_receipt_event_id_for_user.invalidate( - (user_id, room_id, receipt_type) - ) - self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) - self.get_receipts_for_room.invalidate((room_id, receipt_type)) - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == ReceiptsStream.NAME: - self._receipts_id_gen.advance(instance_name, token) - for row in rows: - self.invalidate_caches_for_receipt( - row.room_id, row.receipt_type, row.user_id - ) - self._receipts_stream_cache.entity_has_changed(row.room_id, token) - - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1f89249475..317796d5e0 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import ( from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, + AccountDataStream, BackfillStream, CachesStream, EventsStream, FederationStream, + ReceiptsStream, Stream, + TagAccountDataStream, ToDeviceStream, TypingStream, ) @@ -132,6 +135,22 @@ class ReplicationCommandHandler: continue + if isinstance(stream, (AccountDataStream, TagAccountDataStream)): + # Only add AccountDataStream and TagAccountDataStream as a source on the + # instance in charge of account_data persistence. + if hs.get_instance_name() in hs.config.worker.writers.account_data: + self._streams_to_replicate.append(stream) + + continue + + if isinstance(stream, ReceiptsStream): + # Only add ReceiptsStream as a source on the instance in charge of + # receipts. + if hs.get_instance_name() in hs.config.worker.writers.receipts: + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html new file mode 100644 index 0000000000..3611191bf9 --- /dev/null +++ b/synapse/res/templates/sso_auth_bad_user.html @@ -0,0 +1,18 @@ +<html> +<head> + <title>Authentication Failed</title> +</head> + <body> + <div> + <p> + We were unable to validate your <tt>{{server_name | e}}</tt> account via + single-sign-on (SSO), because the SSO Identity Provider returned + different details than when you logged in. + </p> + <p> + Try the operation again, and ensure that you use the same details on + the Identity Provider as when you log into your account. + </p> + </div> + </body> +</html> diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index f53c9cd679..5b38481012 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -17,6 +17,9 @@ <li> <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}"> <label for="prov{{loop.index}}">{{p.idp_name | e}}</label> +{% if p.idp_icon %} + <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/> +{% endif %} </li> {% endfor %} </ul> diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index c82b4f87d6..8720b1401f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -15,6 +15,9 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_boolean, parse_integer @@ -23,6 +26,10 @@ from synapse.rest.admin._base import ( assert_requester_is_admin, assert_user_is_admin, ) +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet): admin_patterns("/quarantine_media/(?P<room_id>[^/]+)") ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, room_id: str): + async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, user_id: str): + async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet): "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, server_name: str, media_id: str): + async def on_POST( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet): return 200, {} +class ProtectMediaByID(RestServlet): + """Protect local media from being quarantined. + """ + + PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)") + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Protecting local media by ID: %s", media_id) + + # Quarantine this media id + await self.store.mark_local_media_as_safe(media_id) + + return 200, {} + + class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: @@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet): class PurgeMediaCacheRestServlet(RestServlet): PATTERNS = admin_patterns("/purge_media_cache") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.media_repository = hs.get_media_repository() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_DELETE(self, request, server_name: str, media_id: str): + async def on_DELETE( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: @@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_POST(self, request, server_name: str): + async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet): return 200, {"deleted_media": deleted_media, "total": total} -def register_servlets_for_media_repo(hs, http_server): +def register_servlets_for_media_repo(hs: "HomeServer", http_server): """ Media repo specific APIs. """ @@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server): QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) + ProtectMediaByID(hs).register(http_server) ListMediaInRoom(hs).register(http_server) DeleteMediaByID(hs).register(http_server) DeleteMediaByDateSize(hs).register(http_server) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 6658c2da56..86198bab30 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet): The parameter `deactivated` can be used to include deactivated users. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + user_id = parse_string(request, "user_id", default=None) name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) @@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet): start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} - if len(users) >= limit: + if (start + limit) < total: ret["next_token"] = str(start + len(users)) return 200, ret @@ -244,7 +259,7 @@ class UserRestServletV2(RestServlet): if deactivate and not user["deactivated"]: await self.deactivate_account_handler.deactivate_account( - target_user.to_string(), False + target_user.to_string(), False, requester, by_admin=True ) elif not deactivate and user["deactivated"]: if "password" not in body: @@ -486,12 +501,22 @@ class WhoisRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.store = hs.get_datastore() + + async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + if not self.is_mine(UserID.from_string(target_user_id)): + raise SynapseError(400, "Can only deactivate local users") + + if not await self.store.get_user_by_id(target_user_id): + raise NotFoundError("User not found") - async def on_POST(self, request, target_user_id): - await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -501,10 +526,8 @@ class DeactivateAccountRestServlet(RestServlet): Codes.BAD_JSON, ) - UserID.from_string(target_user_id) - result = await self._deactivate_account_handler.deactivate_account( - target_user_id, erase + target_user_id, erase, requester, by_admin=True ) if result: id_server_unbind_result = "success" @@ -714,13 +737,6 @@ class UserMembershipRestServlet(RestServlet): async def on_GET(self, request, user_id): await assert_requester_is_admin(self.auth, request) - if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") - - user = await self.store.get_user_by_id(user_id) - if user is None: - raise NotFoundError("Unknown user") - room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} return 200, ret diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5647e8c577..f95627ee61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder -from synapse.util.stringutils import random_string +from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: import synapse.server @@ -347,8 +347,6 @@ class PublicRoomListRestServlet(TransactionRestServlet): # provided. if server: raise e - else: - pass limit = parse_integer(request, "limit", 0) since_token = parse_string(request, "since", None) @@ -359,6 +357,14 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server and server != self.hs.config.server_name: + # Ensure the server is valid. + try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + ) + try: data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token @@ -402,6 +408,14 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server and server != self.hs.config.server_name: + # Ensure the server is valid. + try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + ) + try: data = await handler.get_remote_public_room_list( server, diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index d837bde1d6..65e68d641b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -20,9 +20,6 @@ from http import HTTPStatus from typing import TYPE_CHECKING from urllib.parse import urlparse -if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer - from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, @@ -31,6 +28,7 @@ from synapse.api.errors import ( ThreepidValidationError, ) from synapse.config.emailconfig import ThreepidBehaviour +from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, @@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + + logger = logging.getLogger(__name__) @@ -189,11 +191,7 @@ class PasswordRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) try: params, session_id = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", + requester, request, body, "modify your account password", ) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth, but @@ -204,7 +202,9 @@ class PasswordRestServlet(RestServlet): if new_password: password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( - e.session_id, "password_hash", password_hash + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, ) raise user_id = requester.user.to_string() @@ -215,7 +215,6 @@ class PasswordRestServlet(RestServlet): [[LoginType.EMAIL_IDENTITY]], request, body, - self.hs.get_ip_from_request(request), "modify your account password", ) except InteractiveAuthIncompleteError as e: @@ -227,7 +226,9 @@ class PasswordRestServlet(RestServlet): if new_password: password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( - e.session_id, "password_hash", password_hash + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, ) raise @@ -260,7 +261,7 @@ class PasswordRestServlet(RestServlet): password_hash = await self.auth_handler.hash(new_password) elif session_id is not None: password_hash = await self.auth_handler.get_session_data( - session_id, "password_hash", None + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None ) else: # UI validation was skipped, but the request did not include a new @@ -304,19 +305,18 @@ class DeactivateAccountRestServlet(RestServlet): # allow ASes to deactivate their own users if requester.app_service: await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase + requester.user.to_string(), erase, requester ) return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "deactivate your account", + requester, request, body, "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, id_server=body.get("id_server") + requester.user.to_string(), + erase, + requester, + id_server=body.get("id_server"), ) if result: id_server_unbind_result = "success" @@ -695,11 +695,7 @@ class ThreepidAddRestServlet(RestServlet): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "add a third-party identifier to your account", + requester, request, body, "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 87a5b1b86b..3f28c0bc3e 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") body = parse_json_object_from_request(request) - max_id = await self.store.add_account_data_for_user( - user_id, account_data_type, body - ) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_account_data_for_user(user_id, account_data_type, body) return 200, {} @@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = await self.store.add_account_data_to_room( + await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, body ) - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - return 200, {} async def on_GET(self, request, user_id, room_id, account_data_type): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 9b9514632f..75ece1c911 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -19,7 +19,6 @@ from typing import TYPE_CHECKING from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.handlers.sso import SsoIdentityProvider from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet, parse_string @@ -46,22 +45,6 @@ class AuthRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - - # SSO configuration. - self._cas_enabled = hs.config.cas_enabled - if self._cas_enabled: - self._cas_handler = hs.get_cas_handler() - self._cas_server_url = hs.config.cas_server_url - self._cas_service_url = hs.config.cas_service_url - self._saml_enabled = hs.config.saml2_enabled - if self._saml_enabled: - self._saml_handler = hs.get_saml_handler() - self._oidc_enabled = hs.config.oidc_enabled - if self._oidc_enabled: - self._oidc_handler = hs.get_oidc_handler() - self._cas_server_url = hs.config.cas_server_url - self._cas_service_url = hs.config.cas_service_url - self.recaptcha_template = hs.config.recaptcha_template self.terms_template = hs.config.terms_template self.success_template = hs.config.fallback_success_template @@ -90,21 +73,7 @@ class AuthRestServlet(RestServlet): elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to # re-authenticate with their SSO provider. - - if self._cas_enabled: - sso_auth_provider = self._cas_handler # type: SsoIdentityProvider - elif self._saml_enabled: - sso_auth_provider = self._saml_handler - elif self._oidc_enabled: - sso_auth_provider = self._oidc_handler - else: - raise SynapseError(400, "Homeserver not configured for SSO.") - - sso_redirect_url = await sso_auth_provider.handle_redirect_request( - request, None, session - ) - - html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + html = await self.auth_handler.start_sso_ui_auth(request, session) else: raise SynapseError(404, "Unknown auth stage type") @@ -128,7 +97,7 @@ class AuthRestServlet(RestServlet): authdict = {"response": response, "session": session} success = await self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) + LoginType.RECAPTCHA, authdict, request.getClientIP() ) if success: @@ -144,7 +113,7 @@ class AuthRestServlet(RestServlet): authdict = {"session": session} success = await self.auth_handler.add_oob_auth( - LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) + LoginType.TERMS, authdict, request.getClientIP() ) if success: diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index af117cb27c..314e01dfe4 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "remove device(s) from your account", + requester, request, body, "remove device(s) from your account", ) await self.device_handler.delete_devices( @@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet): raise await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "remove a device from your account", + requester, request, body, "remove a device from your account", ) await self.device_handler.delete_device(requester.user.to_string(), device_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index b91996c738..a6134ead8a 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "add a device signing key to your account", + requester, request, body, "add a device signing key to your account", ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6b5a1b7109..b093183e79 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved from synapse.handlers.auth import AuthHandler +from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, @@ -353,7 +354,7 @@ class UsernameAvailabilityRestServlet(RestServlet): 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) - ip = self.hs.get_ip_from_request(request) + ip = request.getClientIP() with self.ratelimiter.ratelimit(ip) as wait_deferred: await wait_deferred @@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet): # user here. We carry on and go through the auth checks though, # for paranoia. registered_user_id = await self.auth_handler.get_session_data( - session_id, "registered_user_id", None + session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None ) # Extract the previously-hashed password from the session. password_hash = await self.auth_handler.get_session_data( - session_id, "password_hash", None + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None ) # Ensure that the username is valid. @@ -513,11 +514,7 @@ class RegisterRestServlet(RestServlet): # not this will raise a user-interactive auth error. try: auth_result, params, session_id = await self.auth_handler.check_ui_auth( - self._registration_flows, - request, - body, - self.hs.get_ip_from_request(request), - "register a new account", + self._registration_flows, request, body, "register a new account", ) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth. @@ -532,7 +529,9 @@ class RegisterRestServlet(RestServlet): if not password_hash and password: password_hash = await self.auth_handler.hash(password) await self.auth_handler.set_session_data( - e.session_id, "password_hash", password_hash + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, ) raise @@ -633,7 +632,9 @@ class RegisterRestServlet(RestServlet): # Remember that the user account has been registered (and the user # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( - session_id, "registered_user_id", registered_user_id + session_id, + UIAuthSessionDataConstants.REGISTERED_USER_ID, + registered_user_id, ) registered = True diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index bf3a79db44..a97cd66c52 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -58,8 +58,7 @@ class TagServlet(RestServlet): def __init__(self, hs): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, tag): requester = await self.auth.get_user_by_req(request) @@ -68,9 +67,7 @@ class TagServlet(RestServlet): body = parse_json_object_from_request(request) - max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_tag_to_room(user_id, room_id, tag, body) return 200, {} @@ -79,9 +76,7 @@ class TagServlet(RestServlet): if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.remove_tag_from_room(user_id, room_id, tag) return 200, {} diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 47c2b44bff..f71a03a12d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,10 +17,11 @@ import logging import os import urllib -from typing import Awaitable +from typing import Awaitable, Dict, Generator, List, Optional, Tuple from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json @@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [ ] -def parse_media_id(request): +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: try: # This allows users to append e.g. /test.png to the URL. Useful for # clients that parse the URL to see content type. @@ -69,7 +70,7 @@ def parse_media_id(request): ) -def respond_404(request): +def respond_404(request: Request) -> None: respond_with_json( request, 404, @@ -79,8 +80,12 @@ def respond_404(request): async def respond_with_file( - request, media_type, file_path, file_size=None, upload_name=None -): + request: Request, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): @@ -98,15 +103,20 @@ async def respond_with_file( respond_404(request) -def add_file_headers(request, media_type, file_size, upload_name): +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: """Adds the correct response headers in preparation for responding with the media. Args: - request (twisted.web.http.Request) - media_type (str): The media/content type. - file_size (int): Size in bytes of the media, if known. - upload_name (str): The name of the requested file, if any. + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. """ def _quote(x): @@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name): # select private. don't bother setting Expires as all our # clients are smart enough to be happy with Cache-Control request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - request.setHeader(b"Content-Length", b"%d" % (file_size,)) + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) # Tell web crawlers to not index, archive, or follow links in media. This # should help to prevent things in the media repo from showing up in web @@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = { } -def _can_encode_filename_as_token(x): +def _can_encode_filename_as_token(x: str) -> bool: for c in x: # from RFC2616: # @@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x): async def respond_with_responder( - request, responder, media_type, file_size, upload_name=None -): + request: Request, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: """Responds to the request with given responder. If responder is None then returns 404. Args: - request (twisted.web.http.Request) - responder (Responder|None) - media_type (str): The media/content type. - file_size (int|None): Size in bytes of the media. If not known it should be None - upload_name (str|None): The name of the requested file, if any. + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. """ if request._disconnected: logger.warning( @@ -285,6 +300,7 @@ class FileInfo: thumbnail_height (int) thumbnail_method (str) thumbnail_type (str): Content type of thumbnail, e.g. image/png + thumbnail_length (int): The size of the media file, in bytes. """ def __init__( @@ -297,6 +313,7 @@ class FileInfo: thumbnail_height=None, thumbnail_method=None, thumbnail_type=None, + thumbnail_length=None, ): self.server_name = server_name self.file_id = file_id @@ -306,24 +323,25 @@ class FileInfo: self.thumbnail_height = thumbnail_height self.thumbnail_method = thumbnail_method self.thumbnail_type = thumbnail_type + self.thumbnail_length = thumbnail_length -def get_filename_from_headers(headers): +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: """ Get the filename of the downloaded file by inspecting the Content-Disposition HTTP header. Args: - headers (dict[bytes, list[bytes]]): The HTTP request headers. + headers: The HTTP request headers. Returns: - A Unicode string of the filename, or None. + The filename, or None. """ content_disposition = headers.get(b"Content-Disposition", [b""]) # No header, bail out. if not content_disposition[0]: - return + return None _, params = _parse_header(content_disposition[0]) @@ -356,17 +374,16 @@ def get_filename_from_headers(headers): return upload_name -def _parse_header(line): +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: """Parse a Content-type like header. Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - line (bytes): header to be parsed + line: header to be parsed Returns: - Tuple[bytes, dict[bytes, bytes]]: - the main content-type, followed by the parameter dictionary + The main content-type, followed by the parameter dictionary """ parts = _parseparam(b";" + line) key = next(parts) @@ -386,16 +403,16 @@ def _parse_header(line): return key, pdict -def _parseparam(s): +def _parseparam(s: bytes) -> Generator[bytes, None, None]: """Generator which splits the input on ;, respecting double-quoted sequences Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - s (bytes): header to be parsed + s: header to be parsed Returns: - Iterable[bytes]: the split input + The split input """ while s[:1] == b";": s = s[1:] diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 68dd2a1c8a..4e4c6971f7 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 Will Hunt <will@half-shot.uk> +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,22 +15,29 @@ # limitations under the License. # +from typing import TYPE_CHECKING + +from twisted.web.http import Request + from synapse.http.server import DirectServeJsonResource, respond_with_json +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + class MediaConfigResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() config = hs.get_config() self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index d3d8457303..3ed219ae43 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,24 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request -import synapse.http.servlet from synapse.http.server import DirectServeJsonResource, set_cors_headers +from synapse.http.servlet import parse_boolean from ._base import parse_media_id, respond_404 +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class DownloadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) request.setHeader( b"Content-Security-Policy", @@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource): if server_name == self.server_name: await self.media_repo.get_local_media(request, media_id, name) else: - allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True - ) + allow_remote = parse_boolean(request, "allow_remote", default=True) if not allow_remote: logger.info( "Rejecting request for remote media %s/%s due to allow_remote", diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 9e079f672f..7792f26e78 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +17,12 @@ import functools import os import re +from typing import Callable, List NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") -def _wrap_in_base_path(func): +def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]": """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ @@ -41,12 +43,18 @@ class MediaFilePaths: to write to the backup media store (when one is configured) """ - def __init__(self, primary_base_path): + def __init__(self, primary_base_path: str): self.base_path = primary_base_path def default_thumbnail_rel( - self, default_top_level, default_sub_type, width, height, content_type, method - ): + self, + default_top_level: str, + default_sub_type: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -55,12 +63,14 @@ class MediaFilePaths: default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) - def local_media_filepath_rel(self, media_id): + def local_media_filepath_rel(self, media_id: str) -> str: return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): + def local_media_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -86,7 +96,7 @@ class MediaFilePaths: media_id[4:], ) - def remote_media_filepath_rel(self, server_name, file_id): + def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) @@ -94,8 +104,14 @@ class MediaFilePaths: remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) def remote_media_thumbnail_rel( - self, server_name, file_id, width, height, content_type, method - ): + self, + server_name: str, + file_id: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -113,7 +129,7 @@ class MediaFilePaths: # Should be removed after some time, when most of the thumbnails are stored # using the new path. def remote_media_thumbnail_rel_legacy( - self, server_name, file_id, width, height, content_type + self, server_name: str, file_id: str, width: int, height: int, content_type: str ): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) @@ -126,7 +142,7 @@ class MediaFilePaths: file_name, ) - def remote_media_thumbnail_dir(self, server_name, file_id): + def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, "remote_thumbnail", @@ -136,7 +152,7 @@ class MediaFilePaths: file_id[4:], ) - def url_cache_filepath_rel(self, media_id): + def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -146,7 +162,7 @@ class MediaFilePaths: url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - def url_cache_filepath_dirs_to_delete(self, media_id): + def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): return [os.path.join(self.base_path, "url_cache", media_id[:10])] @@ -156,7 +172,9 @@ class MediaFilePaths: os.path.join(self.base_path, "url_cache", media_id[0:2]), ] - def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): + def url_cache_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -178,7 +196,7 @@ class MediaFilePaths: url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - def url_cache_thumbnail_directory(self, media_id): + def url_cache_thumbnail_directory(self, media_id: str) -> str: # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -195,7 +213,7 @@ class MediaFilePaths: media_id[4:], ) - def url_cache_thumbnail_dirs_to_delete(self, media_id): + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 83beb02b05..4c9946a616 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import errno import logging import os import shutil -from typing import IO, Dict, List, Optional, Tuple +from io import BytesIO +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple import twisted.internet.error import twisted.web.http @@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 class MediaRepository: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.client = hs.get_federation_http_client() @@ -73,16 +76,16 @@ class MediaRepository: self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.primary_base_path = hs.config.media_store_path - self.filepaths = MediaFilePaths(self.primary_base_path) + self.primary_base_path = hs.config.media_store_path # type: str + self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") - self.recently_accessed_remotes = set() - self.recently_accessed_locals = set() + self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] + self.recently_accessed_locals = set() # type: Set[str] self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -113,7 +116,7 @@ class MediaRepository: "update_recently_accessed_media", self._update_recently_accessed ) - async def _update_recently_accessed(self): + async def _update_recently_accessed(self) -> None: remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() @@ -124,12 +127,12 @@ class MediaRepository: local_media, remote_media, self.clock.time_msec() ) - def mark_recently_accessed(self, server_name, media_id): + def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: """Mark the given media as recently accessed. Args: - server_name (str|None): Origin server of media, or None if local - media_id (str): The media ID of the content + server_name: Origin server of media, or None if local + media_id: The media ID of the content """ if server_name: self.recently_accessed_remotes.add((server_name, media_id)) @@ -459,7 +462,14 @@ class MediaRepository: def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): + def _generate_thumbnail( + self, + thumbnailer: Thumbnailer, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[BytesIO]: m_width = thumbnailer.width m_height = thumbnailer.height @@ -470,22 +480,20 @@ class MediaRepository: m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = thumbnailer.transpose() if t_method == "crop": - t_byte_source = thumbnailer.crop(t_width, t_height, t_type) + return thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - t_byte_source = thumbnailer.scale(t_width, t_height, t_type) - else: - t_byte_source = None + return thumbnailer.scale(t_width, t_height, t_type) - return t_byte_source + return None async def generate_local_exact_thumbnail( self, @@ -776,7 +784,7 @@ class MediaRepository: return {"width": m_width, "height": m_height} - async def delete_old_remote_media(self, before_ts): + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: old_media = await self.store.get_remote_media_before(before_ts) deleted = 0 @@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource): within a given rectangle. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # If we're not configured to use it, raise if we somehow got here. if not hs.config.can_load_media_repo: raise ConfigError("Synapse is not configured to use a media repo.") diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 268e0c8f50..89cdd605aa 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vecotr Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ import os import shutil from typing import IO, TYPE_CHECKING, Any, Optional, Sequence +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.logging.context import defer_to_thread, make_deferred_yieldable @@ -270,7 +272,7 @@ class MediaStorage: return self.filepaths.local_media_filepath_rel(file_info.file_id) -def _write_file_synchronously(source, dest): +def _write_file_synchronously(source: IO, dest: IO) -> None: """Write `source` to the file like `dest` synchronously. Should be called from a thread. @@ -286,14 +288,14 @@ class FileResponder(Responder): """Wraps an open file that can be sent to a request. Args: - open_file (file): A file like object to be streamed ot the client, + open_file: A file like object to be streamed ot the client, is closed when finished streaming. """ - def __init__(self, open_file): + def __init__(self, open_file: IO): self.open_file = open_file - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Deferred: return make_deferred_yieldable( FileSender().beginFileTransfer(self.open_file, consumer) ) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 1082389d9b..a632099167 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import datetime import errno import fnmatch @@ -23,12 +23,13 @@ import re import shutil import sys import traceback -from typing import Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union from urllib import parse as urlparse import attr from twisted.internet.error import DNSLookupError +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient @@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.rest.media.v1.media_storage import MediaStorage from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string from ._base import FileInfo +if TYPE_CHECKING: + from lxml import etree + + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) @@ -119,7 +127,12 @@ class OEmbedError(Exception): class PreviewUrlResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.auth = hs.get_auth() @@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource): self._start_expire_url_cache_data, 10 * 1000 ) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: request.setHeader(b"Allow", b"OPTIONS, GET") respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) @@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) raise OEmbedError() from e - async def _download_url(self, url: str, user): + async def _download_url(self, url: str, user: str) -> Dict[str, Any]: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource): "expire_url_cache_data", self._expire_url_cache_data ) - async def _expire_url_cache_data(self): + async def _expire_url_cache_data(self) -> None: """Clean up expired url cache content, media and thumbnails. """ # TODO: Delete from backup media store @@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]: +def decode_and_calc_og( + body: bytes, media_uri: str, request_encoding: Optional[str] = None +) -> Dict[str, Optional[str]]: # If there's no body, nothing useful is going to be found. if not body: return {} @@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str] return og -def _calc_og(tree, media_uri): +def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them @@ -801,7 +816,9 @@ def _calc_og(tree, media_uri): for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) og["og:description"] = summarize_paragraphs(text_nodes) - else: + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) og["og:description"] = summarize_paragraphs([og["og:description"]]) # TODO: delete the url downloads to stop diskfilling, @@ -809,7 +826,9 @@ def _calc_og(tree, media_uri): return og -def _iterate_over_text(tree, *tags_to_ignore): +def _iterate_over_text( + tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] +) -> Generator[str, None, None]: """Iterate over the tree returning text nodes in a depth first fashion, skipping text nodes inside certain tags. """ @@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore): ) -def _rebase_url(url, base): - base = list(urlparse.urlparse(base)) - url = list(urlparse.urlparse(url)) - if not url[0]: # fix up schema - url[0] = base[0] or "http" - if not url[1]: # fix up hostname - url[1] = base[1] - if not url[2].startswith("/"): - url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] - return urlparse.urlunparse(url) +def _rebase_url(url: str, base: str) -> str: + base_parts = list(urlparse.urlparse(base)) + url_parts = list(urlparse.urlparse(url)) + if not url_parts[0]: # fix up schema + url_parts[0] = base_parts[0] or "http" + if not url_parts[1]: # fix up hostname + url_parts[1] = base_parts[1] + if not url_parts[2].startswith("/"): + url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] + return urlparse.urlunparse(url_parts) -def _is_media(content_type): - if content_type.lower().startswith("image/"): - return True +def _is_media(content_type: str) -> bool: + return content_type.lower().startswith("image/") -def _is_html(content_type): +def _is_html(content_type: str) -> bool: content_type = content_type.lower() - if content_type.startswith("text/html") or content_type.startswith( + return content_type.startswith("text/html") or content_type.startswith( "application/xhtml" - ): - return True + ) -def summarize_paragraphs(text_nodes, min_size=200, max_size=500): +def summarize_paragraphs( + text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 +) -> Optional[str]: # Try to get a summary of between 200 and 500 words, respecting # first paragraph and then word boundaries. # TODO: Respect sentences? diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 67f67efde7..e92006faa9 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging import os import shutil -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background @@ -27,13 +28,17 @@ from .media_storage import FileResponder logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer -class StorageProvider: + +class StorageProvider(metaclass=abc.ABCMeta): """A storage provider is a service that can store uploaded media and retrieve them. """ - async def store_file(self, path: str, file_info: FileInfo): + @abc.abstractmethod + async def store_file(self, path: str, file_info: FileInfo) -> None: """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. @@ -42,6 +47,7 @@ class StorageProvider: file_info: The metadata of the file. """ + @abc.abstractmethod async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. @@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider): self.store_synchronous = store_synchronous self.store_remote = store_remote - def __str__(self): + def __str__(self) -> str: return "StorageProviderWrapper[%s]" % (self.backend,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: if not file_info.server_name and not self.store_local: return None @@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider): if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - return await maybe_awaitable(self.backend.store_file(path, file_info)) + await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore else: # TODO: Handle errors. async def store(): @@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider): logger.exception("Error storing file") run_in_background(store) - return None - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: # store_file is supposed to return an Awaitable, but guard # against improper implementations. return await maybe_awaitable(self.backend.fetch(path, file_info)) @@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider): """A storage provider that stores files in a directory on a filesystem. Args: - hs (HomeServer) + hs config: The config returned by `parse_config`. """ - def __init__(self, hs, config): + def __init__(self, hs: "HomeServer", config: str): self.hs = hs self.cache_directory = hs.config.media_store_path self.base_directory = config @@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return await defer_to_thread( + await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): return FileResponder(open(backup_fname, "rb")) + return None + @staticmethod - def parse_config(config): + def parse_config(config: dict) -> str: """Called on startup to parse config supplied. This should parse the config and raise if there is a problem. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 30421b663a..d653a58be9 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +16,14 @@ import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from twisted.web.http import Request from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string +from synapse.rest.media.v1.media_storage import MediaStorage from ._base import ( FileInfo, @@ -28,13 +33,22 @@ from ._base import ( respond_with_responder, ) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class ThumbnailResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.store = hs.get_datastore() @@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width", required=True) @@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource): self.media_repo.mark_recently_accessed(server_name, media_id) async def _respond_local_thumbnail( - self, request, media_id, width, height, method, m_type - ): + self, + request: Request, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -86,41 +106,27 @@ class ThumbnailResource(DirectServeJsonResource): return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) - - if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos - ) - - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=media_info["url_cache"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) - else: - logger.info("Couldn't find any generated thumbnails") - respond_404(request) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_id, + url_cache=media_info["url_cache"], + server_name=None, + ) async def _select_or_generate_local_thumbnail( self, - request, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -178,14 +184,14 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_remote_thumbnail( self, - request, - server_name, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + server_name: str, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = await self.store.get_remote_media_thumbnails( @@ -239,8 +245,15 @@ class ThumbnailResource(DirectServeJsonResource): raise SynapseError(400, "Failed to generate thumbnail.") async def _respond_remote_thumbnail( - self, request, server_name, media_id, width, height, method, m_type - ): + self, + request: Request, + server_name: str, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. @@ -249,97 +262,185 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_info["filesystem_id"], + url_cache=None, + server_name=server_name, + ) + async def _select_and_respond_with_thumbnail( + self, + request: Request, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str] = None, + server_name: Optional[str] = None, + ) -> None: + """ + Respond to a request with an appropriate thumbnail from the previously generated thumbnails. + + Args: + request: The incoming request. + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + """ if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos + file_info = self._select_thumbnail( + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + file_id, + url_cache, + server_name, ) - file_info = FileInfo( - server_name=server_name, - file_id=media_info["filesystem_id"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] + if not file_info: + logger.info("Couldn't find a thumbnail matching the desired inputs") + respond_404(request) + return responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, file_info.thumbnail_type, file_info.thumbnail_length + ) else: logger.info("Failed to find any generated thumbnails") respond_404(request) def _select_thumbnail( self, - desired_width, - desired_height, - desired_method, - desired_type, - thumbnail_infos, - ): + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str], + server_name: Optional[str], + ) -> Optional[FileInfo]: + """ + Choose an appropriate thumbnail from the previously generated thumbnails. + + Args: + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + + Returns: + The thumbnail which best matches the desired parameters. + """ + desired_method = desired_method.lower() + + # The chosen thumbnail. + thumbnail_info = None + d_w = desired_width d_h = desired_height - if desired_method.lower() == "crop": + if desired_method == "crop": + # Thumbnails that match equal or larger sizes of desired width/height. crop_info_list = [] + # Other thumbnails. crop_info_list2 = [] for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "crop": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] - if t_method == "crop": - aspect_quality = abs(d_w * t_h - d_h * t_w) - min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 - size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] - if t_w >= d_w or t_h >= d_h: - crop_info_list.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + aspect_quality = abs(d_w * t_h - d_h * t_w) + min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 + size_quality = abs((d_w - t_w) * (d_h - t_h)) + type_quality = desired_type != info["thumbnail_type"] + length_quality = info["thumbnail_length"] + if t_w >= d_w or t_h >= d_h: + crop_info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) - else: - crop_info_list2.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + ) + else: + crop_info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) + ) if crop_info_list: - return min(crop_info_list)[-1] - else: - return min(crop_info_list2)[-1] - else: + thumbnail_info = min(crop_info_list)[-1] + elif crop_info_list2: + thumbnail_info = min(crop_info_list2)[-1] + elif desired_method == "scale": + # Thumbnails that match equal or larger sizes of desired width/height. info_list = [] + # Other thumbnails. info_list2 = [] + for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "scale": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] size_quality = abs((d_w - t_w) * (d_h - t_h)) type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] - if t_method == "scale" and (t_w >= d_w or t_h >= d_h): + if t_w >= d_w or t_h >= d_h: info_list.append((size_quality, type_quality, length_quality, info)) - elif t_method == "scale": + else: info_list2.append( (size_quality, type_quality, length_quality, info) ) if info_list: - return min(info_list)[-1] - else: - return min(info_list2)[-1] + thumbnail_info = min(info_list)[-1] + elif info_list2: + thumbnail_info = min(info_list2)[-1] + + if thumbnail_info: + return FileInfo( + file_id=file_id, + url_cache=url_cache, + server_name=server_name, + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + thumbnail_length=thumbnail_info["thumbnail_length"], + ) + + # No matching thumbnail was found. + return None diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 32a8e4f960..07903e4017 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +15,7 @@ # limitations under the License. import logging from io import BytesIO +from typing import Tuple from PIL import Image @@ -39,7 +41,7 @@ class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} - def __init__(self, input_path): + def __init__(self, input_path: str): try: self.image = Image.open(input_path) except OSError as e: @@ -59,11 +61,11 @@ class Thumbnailer: # A lot of parsing errors can happen when parsing EXIF logger.info("Error parsing image EXIF information: %s", e) - def transpose(self): + def transpose(self) -> Tuple[int, int]: """Transpose the image using its EXIF Orientation tag Returns: - Tuple[int, int]: (width, height) containing the new image size in pixels. + A tuple containing the new image size in pixels as (width, height). """ if self.transpose_method is not None: self.image = self.image.transpose(self.transpose_method) @@ -73,7 +75,7 @@ class Thumbnailer: self.image.info["exif"] = None return self.image.size - def aspect(self, max_width, max_height): + def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]: """Calculate the largest size that preserves aspect ratio which fits within the given rectangle:: @@ -91,7 +93,7 @@ class Thumbnailer: else: return (max_height * self.width) // self.height, max_height - def _resize(self, width, height): + def _resize(self, width: int, height: int) -> Image: # 1-bit or 8-bit color palette images need converting to RGB # otherwise they will be scaled using nearest neighbour which # looks awful @@ -99,7 +101,7 @@ class Thumbnailer: self.image = self.image.convert("RGB") return self.image.resize((width, height), Image.ANTIALIAS) - def scale(self, width, height, output_type): + def scale(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales the image to the given dimensions. Returns: @@ -108,7 +110,7 @@ class Thumbnailer: scaled = self._resize(width, height) return self._encode_image(scaled, output_type) - def crop(self, width, height, output_type): + def crop(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -136,7 +138,7 @@ class Thumbnailer: cropped = scaled_image.crop((crop_left, 0, crop_right, height)) return self._encode_image(cropped, output_type) - def _encode_image(self, output_image, output_type): + def _encode_image(self, output_image: Image, output_type: str) -> BytesIO: output_bytes_io = BytesIO() fmt = self.FORMATS[output_type] if fmt == "JPEG": diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 42febc9afc..6da76ae994 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,18 +15,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class UploadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo @@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_POST(self, request): + async def _async_render_POST(self, request: Request) -> None: requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py index e5b720bbca..9550b82998 100644 --- a/synapse/rest/synapse/client/pick_idp.py +++ b/synapse/rest/synapse/client/pick_idp.py @@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource): self._server_name = hs.hostname async def _async_render_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl", required=True) + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding="utf-8" + ) idp = parse_string(request, "idp", required=False) # if we need to pick an IdP, do so diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index f591cc6c5c..241fe746d9 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,10 +34,6 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self): - # if we don't have a public_baseurl, we can't help much here. - if self._config.public_baseurl is None: - return None - result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: diff --git a/synapse/server.py b/synapse/server.py index a198b0eb46..9cdda83aa1 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler +from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.acme import AcmeHandler from synapse.handlers.admin import AdminHandler @@ -283,10 +284,6 @@ class HomeServer(metaclass=abc.ABCMeta): """ return self._reactor - def get_ip_from_request(self, request) -> str: - # X-Forwarded-For is handled by our custom request type. - return request.getClientIP() - def is_mine(self, domain_specific_string: DomainSpecificString) -> bool: return domain_specific_string.domain == self.hostname @@ -505,7 +502,7 @@ class HomeServer(metaclass=abc.ABCMeta): return InitialSyncHandler(self) @cache_in_self - def get_profile_handler(self): + def get_profile_handler(self) -> ProfileHandler: return ProfileHandler(self) @cache_in_self @@ -715,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_module_api(self) -> ModuleApi: return ModuleApi(self, self.get_auth_handler()) + @cache_in_self + def get_account_data_handler(self) -> AccountDataHandler: + return AccountDataHandler(self) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 2258d306d9..8dd01fce76 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -42,6 +42,7 @@ class ResourceLimitsServerNotices: self._auth = hs.get_auth() self._config = hs.config self._resouce_limited = False + self._account_data_handler = hs.get_account_data_handler() self._message_handler = hs.get_message_handler() self._state = hs.get_state_handler() @@ -177,7 +178,7 @@ class ResourceLimitsServerNotices: # tag already present, nothing to do here need_to_set_tag = False if need_to_set_tag: - max_id = await self._store.add_tag_to_room( + max_id = await self._account_data_handler.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 100dbd5e2c..c46b2f047d 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -35,6 +35,7 @@ class ServerNoticesManager: self._store = hs.get_datastore() self._config = hs.config + self._account_data_handler = hs.get_account_data_handler() self._room_creation_handler = hs.get_room_creation_handler() self._room_member_handler = hs.get_room_member_handler() self._event_creation_handler = hs.get_event_creation_handler() @@ -163,7 +164,7 @@ class ServerNoticesManager: ) room_id = info["room_id"] - max_id = await self._store.add_tag_to_room( + max_id = await self._account_data_handler.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b70ca3087b..d2ba4bd2fc 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import Collection # python 3 does not have a maximum int value @@ -179,6 +180,9 @@ class LoggingDatabaseConnection: _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] +R = TypeVar("R") + + class LoggingTransaction: """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -258,13 +262,32 @@ class LoggingTransaction: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: + """Similar to `executemany`, except `txn.rowcount` will not be correct + afterwards. + + More efficient than `executemany` on PostgreSQL + """ + if isinstance(self.database_engine, PostgresEngine): from psycopg2.extras import execute_batch # type: ignore self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: - for val in args: - self.execute(sql, val) + self.executemany(sql, args) + + def execute_values(self, sql: str, *args: Any) -> List[Tuple]: + """Corresponds to psycopg2.extras.execute_values. Only available when + using postgres. + + Always sets fetch=True when caling `execute_values`, so will return the + results. + """ + assert isinstance(self.database_engine, PostgresEngine) + from psycopg2.extras import execute_values # type: ignore + + return self._do_execute( + lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args + ) def execute(self, sql: str, *args: Any) -> None: self._do_execute(self.txn.execute, sql, *args) @@ -276,7 +299,7 @@ class LoggingTransaction: "Strip newlines out of SQL so that the loggers in the DB are on one line" return " ".join(line.strip() for line in sql.splitlines() if line.strip()) - def _do_execute(self, func, sql: str, *args: Any) -> None: + def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R: sql = self._make_sql_one_line(sql) # TODO(paul): Maybe use 'info' and 'debug' for values? @@ -347,9 +370,6 @@ class PerformanceCounters: return top_n_counters -R = TypeVar("R") - - class DatabasePool: """Wraps a single physical database and connection pool. @@ -398,6 +418,16 @@ class DatabasePool: self._check_safe_to_upsert, ) + # We define this sequence here so that it can be referenced from both + # the DataStore and PersistEventStore. + def get_chain_id_txn(txn): + txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") + return txn.fetchone()[0] + + self.event_chain_id_gen = build_sequence_generator( + engine, get_chain_id_txn, "event_auth_chain_id" + ) + def is_running(self) -> bool: """Is the database pool currently running """ @@ -863,7 +893,7 @@ class DatabasePool: ", ".join("?" for _ in keys[0]), ) - txn.executemany(sql, vals) + txn.execute_batch(sql, vals) async def simple_upsert( self, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index b936f54f1e..5d0845588c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -162,9 +162,13 @@ class DataStore( database, stream_name="caches", instance_name=hs.get_instance_name(), - table="cache_invalidation_stream_by_instance", - instance_column="instance_name", - id_column="stream_id", + tables=[ + ( + "cache_invalidation_stream_by_instance", + "instance_name", + "stream_id", + ) + ], sequence_name="cache_invalidation_stream_seq", writers=[], ) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index bff51e92b9..a277a1ef13 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,14 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import Dict, List, Optional, Set, Tuple from synapse.api.constants import AccountDataTypes +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. -class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): +class AccountDataWorkerStore(SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ def __init__(self, database: DatabasePool, db_conn, hs): + self._instance_name = hs.get_instance_name() + + if isinstance(database.engine, PostgresEngine): + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) + + self._account_data_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="account_data", + instance_name=self._instance_name, + tables=[ + ("room_account_data", "instance_name", "stream_id"), + ("room_tags_revisions", "instance_name", "stream_id"), + ("account_data", "instance_name", "stream_id"), + ], + sequence_name="account_data_sequence", + writers=hs.config.worker.writers.account_data, + ) + else: + self._can_write_to_account_data = True + + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.account_data: + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + else: + self._account_data_id_gen = SlavedIdTracker( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max @@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): super().__init__(database, db_conn, hs) - @abc.abstractmethod - def get_max_account_data_stream_id(self): + def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream Returns: int """ - raise NotImplementedError() + return self._account_data_id_gen.get_current_token() @cached() async def get_account_data_for_user( @@ -307,28 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): ) ) - -class AccountDataStore(AccountDataWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "account_data_max_stream_id", - "stream_id", - extra_tables=[ - ("room_account_data", "stream_id"), - ("room_tags_revisions", "stream_id"), - ], - ) - - super().__init__(database, db_conn, hs) - - def get_max_account_data_stream_id(self) -> int: - """Get the current max stream id for the private user data stream - - Returns: - The maximum stream ID. - """ - return self._account_data_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + if not row.room_id: + self.get_global_account_data_by_type_for_user.invalidate( + (row.data_type, row.user_id) + ) + self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) + self.get_account_data_for_room_and_type.invalidate( + (row.user_id, row.room_id, row.data_type) + ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -344,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore): Returns: The maximum stream ID. """ + assert self._can_write_to_account_data + content_json = json_encoder.encode(content) async with self._account_data_id_gen.get_next() as next_id: @@ -362,14 +406,6 @@ class AccountDataStore(AccountDataWorkerStore): lock=False, ) - # it's theoretically possible for the above to succeed and the - # below to fail - in which case we might reuse a stream id on - # restart, and the above update might not get propagated. That - # doesn't sound any worse than the whole update getting lost, - # which is what would happen if we combined the two into one - # transaction. - await self._update_max_stream_id(next_id) - self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_room.invalidate((user_id, room_id)) @@ -392,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore): Returns: The maximum stream ID. """ + assert self._can_write_to_account_data + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "add_user_account_data", @@ -402,18 +440,6 @@ class AccountDataStore(AccountDataWorkerStore): content, ) - # it's theoretically possible for the above to succeed and the - # below to fail - in which case we might reuse a stream id on - # restart, and the above update might not get propagated. That - # doesn't sound any worse than the whole update getting lost, - # which is what would happen if we combined the two into one - # transaction. - # - # Note: This is only here for backwards compat to allow admins to - # roll back to a previous Synapse version. Next time we update the - # database version we can remove this table. - await self._update_max_stream_id(next_id) - self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) self.get_global_account_data_by_type_for_user.invalidate( @@ -487,23 +513,6 @@ class AccountDataStore(AccountDataWorkerStore): for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) - async def _update_max_stream_id(self, next_id: int) -> None: - """Update the max stream_id - - Args: - next_id: The the revision to advance to. - """ - # Note: This is only here for backwards compat to allow admins to - # roll back to a previous Synapse version. Next time we update the - # database version we can remove this table. - - def _update(txn): - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" - ) - txn.execute(update_max_id_sql, (next_id, next_id)) - - await self.db_pool.runInteraction("update_account_data_max_stream_id", _update) +class AccountDataStore(AccountDataWorkerStore): + pass diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c53c836337..ea1e8fb580 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -407,6 +407,34 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): "_prune_old_user_ips", _prune_old_user_ips_txn ) + async def get_last_client_ip_by_device( + self, user_id: str, device_id: Optional[str] + ) -> Dict[Tuple[str, str], dict]: + """For each device_id listed, give the user_ip it was last seen on. + + The result might be slightly out of date as client IPs are inserted in batches. + + Args: + user_id: The user to fetch devices for. + device_id: If None fetches all devices for the user + + Returns: + A dictionary mapping a tuple of (user_id, device_id) to dicts, with + keys giving the column names from the devices table. + """ + + keyvalues = {"user_id": user_id} + if device_id is not None: + keyvalues["device_id"] = device_id + + res = await self.db_pool.simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ) + + return {(d["user_id"], d["device_id"]): d for d in res} + class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -512,18 +540,9 @@ class ClientIpStore(ClientIpWorkerStore): A dictionary mapping a tuple of (user_id, device_id) to dicts, with keys giving the column names from the devices table. """ + ret = await super().get_last_client_ip_by_device(user_id, device_id) - keyvalues = {"user_id": user_id} - if device_id is not None: - keyvalues["device_id"] = device_id - - res = await self.db_pool.simple_select_list( - table="devices", - keyvalues=keyvalues, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), - ) - - ret = {(d["user_id"], d["device_id"]): d for d in res} + # Update what is retrieved from the database with data which is pending insertion. for key in self._batch_row_update: uid, access_token, ip = key if uid == user_id: diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 58d3f71e45..31f70ac5ef 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): db=database, stream_name="to_device", instance_name=self._instance_name, - table="device_inbox", - instance_column="instance_name", - id_column="stream_id", + tables=[("device_inbox", "instance_name", "stream_id")], sequence_name="device_inbox_sequence", writers=hs.config.worker.writers.to_device, ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9097677648..659d8f245f 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore): DELETE FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? """ - txn.executemany(sql, ((row[0], row[1]) for row in rows)) + txn.execute_batch(sql, ((row[0], row[1]) for row in rows)) logger.info("Pruned %d device list outbound pokes", count) @@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Delete older entries in the table, as we really only care about # when the latest change happened. - txn.executemany( + txn.execute_batch( """ DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 4d1b92d1aa..c128889bf9 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import json_encoder @@ -513,21 +514,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): for user_chunk in batch_iter(user_ids, 100): clause, params = make_in_list_sql_clause( - txn.database_engine, "k.user_id", user_chunk - ) - sql = ( - """ - SELECT k.user_id, k.keytype, k.keydata, k.stream_id - FROM e2e_cross_signing_keys k - INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id - FROM e2e_cross_signing_keys - GROUP BY user_id, keytype) s - USING (user_id, stream_id, keytype) - WHERE - """ - + clause + txn.database_engine, "user_id", user_chunk ) + # Fetch the latest key for each type per user. + if isinstance(self.database_engine, PostgresEngine): + # The `DISTINCT ON` clause will pick the *first* row it + # encounters, so ordering by stream ID desc will ensure we get + # the latest key. + sql = """ + SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id + FROM e2e_cross_signing_keys + WHERE %(clause)s + ORDER BY user_id, keytype, stream_id DESC + """ % { + "clause": clause + } + else: + # SQLite has special handling for bare columns when using + # MIN/MAX with a `GROUP BY` clause where it picks the value from + # a row that matches the MIN/MAX. + sql = """ + SELECT user_id, keytype, keydata, MAX(stream_id) + FROM e2e_cross_signing_keys + WHERE %(clause)s + GROUP BY user_id, keytype + """ % { + "clause": clause + } + txn.execute(sql, params) rows = self.db_pool.cursor_to_dict(txn) @@ -707,50 +722,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): """Get the current stream id from the _device_list_id_gen""" ... - -class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): - async def set_e2e_device_keys( - self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict - ) -> bool: - """Stores device keys for a device. Returns whether there was a change - or the keys were already in the database. - """ - - def _set_e2e_device_keys_txn(txn): - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("time_now", time_now) - set_tag("device_keys", device_keys) - - old_key_json = self.db_pool.simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) - - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") - - if old_key_json == new_key_json: - log_kv({"Message": "Device key already stored."}) - return False - - self.db_pool.simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time_now, "key_json": new_key_json}, - ) - log_kv({"message": "Device keys stored."}) - return True - - return await self.db_pool.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn - ) - async def claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str]] ) -> Dict[str, Dict[str, Dict[str, bytes]]]: @@ -840,6 +811,50 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) + +class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + async def set_e2e_device_keys( + self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict + ) -> bool: + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + """ + + def _set_e2e_device_keys_txn(txn): + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("time_now", time_now) + set_tag("device_keys", device_keys) + + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) + + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") + + if old_key_json == new_key_json: + log_kv({"Message": "Device key already stored."}) + return False + + self.db_pool.simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, + ) + log_kv({"message": "Device keys stored."}) + return True + + return await self.db_pool.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn + ) + async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn): log_kv( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ebffd89251..8326640d20 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.types import Cursor from synapse.types import Collection from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) +class _NoChainCoverIndex(Exception): + def __init__(self, room_id: str): + super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) + + class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas The set of the difference in auth chains. """ + # Check if we have indexed the room so we can use the chain cover + # algorithm. + room = await self.get_room(room_id) + if room["has_auth_chain_index"]: + try: + return await self.db_pool.runInteraction( + "get_auth_chain_difference_chains", + self._get_auth_chain_difference_using_cover_index_txn, + room_id, + state_sets, + ) + except _NoChainCoverIndex: + # For whatever reason we don't actually have a chain cover index + # for the events in question, so we fall back to the old method. + pass + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, ) + def _get_auth_chain_difference_using_cover_index_txn( + self, txn: Cursor, room_id: str, state_sets: List[Set[str]] + ) -> Set[str]: + """Calculates the auth chain difference using the chain index. + + See docs/auth_chain_difference_algorithm.md for details + """ + + # First we look up the chain ID/sequence numbers for all the events, and + # work out the chain/sequence numbers reachable from each state set. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Map from event_id -> (chain ID, seq no) + chain_info = {} # type: Dict[str, Tuple[int, int]] + + # Map from chain ID -> seq no -> event Id + chain_to_event = {} # type: Dict[int, Dict[int, str]] + + # All the chains that we've found that are reachable from the state + # sets. + seen_chains = set() # type: Set[int] + + sql = """ + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE %s + """ + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + for event_id, chain_id, sequence_number in txn: + chain_info[event_id] = (chain_id, sequence_number) + seen_chains.add(chain_id) + chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id + + # Check that we actually have a chain ID for all the events. + events_missing_chain_info = initial_events.difference(chain_info) + if events_missing_chain_info: + # This can happen due to e.g. downgrade/upgrade of the server. We + # raise an exception and fall back to the previous algorithm. + logger.info( + "Unexpectedly found that events don't have chain IDs in room %s: %s", + room_id, + events_missing_chain_info, + ) + raise _NoChainCoverIndex(room_id) + + # Corresponds to `state_sets`, except as a map from chain ID to max + # sequence number reachable from the state set. + set_to_chain = [] # type: List[Dict[int, int]] + for state_set in state_sets: + chains = {} # type: Dict[int, int] + set_to_chain.append(chains) + + for event_id in state_set: + chain_id, seq_no = chain_info[event_id] + + chains[chain_id] = max(seq_no, chains.get(chain_id, 0)) + + # Now we look up all links for the chains we have, adding chains to + # set_to_chain that are reachable from each set. + sql = """ + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM event_auth_chain_links + WHERE %s + """ + + # (We need to take a copy of `seen_chains` as we want to mutate it in + # the loop) + for batch in batch_iter(set(seen_chains), 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "origin_chain_id", batch + ) + txn.execute(sql % (clause,), args) + + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in txn: + for chains in set_to_chain: + # chains are only reachable if the origin sequence number of + # the link is less than the max sequence number in the + # origin chain. + if origin_sequence_number <= chains.get(origin_chain_id, 0): + chains[target_chain_id] = max( + target_sequence_number, chains.get(target_chain_id, 0), + ) + + seen_chains.add(target_chain_id) + + # Now for each chain we figure out the maximum sequence number reachable + # from *any* state set and the minimum sequence number reachable from + # *all* state sets. Events in that range are in the auth chain + # difference. + result = set() + + # Mapping from chain ID to the range of sequence numbers that should be + # pulled from the database. + chain_to_gap = {} # type: Dict[int, Tuple[int, int]] + + for chain_id in seen_chains: + min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) + max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain) + + if min_seq_no < max_seq_no: + # We have a non empty gap, try and fill it from the events that + # we have, otherwise add them to the list of gaps to pull out + # from the DB. + for seq_no in range(min_seq_no + 1, max_seq_no + 1): + event_id = chain_to_event.get(chain_id, {}).get(seq_no) + if event_id: + result.add(event_id) + else: + chain_to_gap[chain_id] = (min_seq_no, max_seq_no) + break + + if not chain_to_gap: + # If there are no gaps to fetch, we're done! + return result + + if isinstance(self.database_engine, PostgresEngine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) + WHERE + c.chain_id = l.chain_id + AND min_seq < sequence_number AND sequence_number <= max_seq + """ + + args = [ + (chain_id, min_no, max_no) + for chain_id, (min_no, max_no) in chain_to_gap.items() + ] + + rows = txn.execute_values(sql, args) + result.update(r for r, in rows) + else: + # For SQLite we just fall back to doing a noddy for loop. + sql = """ + SELECT event_id FROM event_auth_chains + WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ? + """ + for chain_id, (min_no, max_no) in chain_to_gap.items(): + txn.execute(sql, (chain_id, min_no, max_no)) + result.update(r for r, in txn) + + return result + def _get_auth_chain_difference_txn( self, txn, state_sets: List[Set[str]] ) -> Set[str]: + """Calculates the auth chain difference using a breadth first search. + + This is used when we don't have a cover index for the room. + """ # Algorithm Description # ~~~~~~~~~~~~~~~~~~~~~ diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index e5c03cc609..438383abe1 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): VALUES (?, ?, ?, ?, ?, ?) """ - txn.executemany( + txn.execute_batch( sql, ( _gen_entry(user_id, actions) @@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ], ) - txn.executemany( + txn.execute_batch( """ UPDATE event_push_summary SET notif_count = ?, unread_count = ?, stream_ordering = ? @@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore): (rotate_to_stream_ordering,), ) + def _remove_old_push_actions_before_txn( + self, txn, room_id, user_id, stream_ordering + ): + """ + Purges old push actions for a user and room before a given + stream_ordering. + + We however keep a months worth of highlighted notifications, so that + users can still get a list of recent highlights. + + Args: + txn: The transcation + room_id: Room ID to delete from + user_id: user ID to delete for + stream_ordering: The lowest stream ordering which will + not be deleted. + """ + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id, user_id), + ) + + # We need to join on the events table to get the received_ts for + # event_push_actions and sqlite won't let us use a join in a delete so + # we can't just delete where received_ts < x. Furthermore we can + # only identify event_push_actions by a tuple of room_id, event_id + # we we can't use a subquery. + # Instead, we look up the stream ordering for the last event in that + # room received before the threshold time and delete event_push_actions + # in the room with a stream_odering before that. + txn.execute( + "DELETE FROM event_push_actions " + " WHERE user_id = ? AND room_id = ? AND " + " stream_ordering <= ?" + " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), + ) + + txn.execute( + """ + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, + (room_id, user_id, stream_ordering), + ) + class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" @@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): - """ - Purges old push actions for a user and room before a given - stream_ordering. - - We however keep a months worth of highlighted notifications, so that - users can still get a list of recent highlights. - - Args: - txn: The transcation - room_id: Room ID to delete from - user_id: user ID to delete for - stream_ordering: The lowest stream ordering which will - not be deleted. - """ - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id, user_id), - ) - - # We need to join on the events table to get the received_ts for - # event_push_actions and sqlite won't let us use a join in a delete so - # we can't just delete where received_ts < x. Furthermore we can - # only identify event_push_actions by a tuple of room_id, event_id - # we we can't use a subquery. - # Instead, we look up the stream ordering for the last event in that - # room received before the threshold time and delete event_push_actions - # in the room with a stream_odering before that. - txn.execute( - "DELETE FROM event_push_actions " - " WHERE user_id = ? AND room_id = ? AND " - " stream_ordering <= ?" - " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), - ) - - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? - """, - (room_id, user_id, stream_ordering), - ) - def _action_has_highlight(actions): for action in actions: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 90fb1a1f00..ccda9f1caa 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,17 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Tuple, +) import attr from prometheus_client import Counter @@ -35,7 +45,7 @@ from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder -from synapse.util.iterutils import batch_iter +from synapse.util.iterutils import batch_iter, sorted_topologically if TYPE_CHECKING: from synapse.server import HomeServer @@ -366,6 +376,36 @@ class PersistEventsStore: # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) + self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts]) + + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only ones that weren't + # rejected. + + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + backfilled=backfilled, + ) + + # We call this last as it assumes we've inserted the events into + # room_memberships, where applicable. + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + + def _persist_event_auth_chain_txn( + self, txn: LoggingTransaction, events: List[EventBase], + ) -> None: + + # We only care about state events, so this if there are no state events. + if not any(e.is_state() for e in events): + return + # We want to store event_auth mappings for rejected events, as they're # used in state res v2. # This is only necessary if the rejected event appears in an accepted @@ -381,31 +421,467 @@ class PersistEventsStore: "room_id": event.room_id, "auth_id": auth_id, } - for event, _ in events_and_contexts + for event in events for auth_id in event.auth_event_ids() if event.is_state() ], ) - # _store_rejected_events_txn filters out any events which were - # rejected, and returns the filtered list. - events_and_contexts = self._store_rejected_events_txn( - txn, events_and_contexts=events_and_contexts + # We now calculate chain ID/sequence numbers for any state events we're + # persisting. We ignore out of band memberships as we're not in the room + # and won't have their auth chain (we'll fix it up later if we join the + # room). + # + # See: docs/auth_chain_difference_algorithm.md + + # We ignore legacy rooms that we aren't filling the chain cover index + # for. + rows = self.db_pool.simple_select_many_txn( + txn, + table="rooms", + column="room_id", + iterable={event.room_id for event in events if event.is_state()}, + keyvalues={}, + retcols=("room_id", "has_auth_chain_index"), ) + rooms_using_chain_index = { + row["room_id"] for row in rows if row["has_auth_chain_index"] + } - # From this point onwards the events are only ones that weren't - # rejected. + state_events = { + event.event_id: event + for event in events + if event.is_state() and event.room_id in rooms_using_chain_index + } - self._update_metadata_tables_txn( + if not state_events: + return + + # We need to know the type/state_key and auth events of the events we're + # calculating chain IDs for. We don't rely on having the full Event + # instances as we'll potentially be pulling more events from the DB and + # we don't need the overhead of fetching/parsing the full event JSON. + event_to_types = { + e.event_id: (e.type, e.state_key) for e in state_events.values() + } + event_to_auth_chain = { + e.event_id: e.auth_event_ids() for e in state_events.values() + } + event_to_room_id = {e.event_id: e.room_id for e in state_events.values()} + + self._add_chain_cover_index( + txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + ) + + @classmethod + def _add_chain_cover_index( + cls, + txn, + db_pool: DatabasePool, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, List[str]], + ) -> None: + """Calculate the chain cover index for the given events. + + Args: + event_to_room_id: Event ID to the room ID of the event + event_to_types: Event ID to type and state_key of the event + event_to_auth_chain: Event ID to list of auth event IDs of the + event (events with no auth events can be excluded). + """ + + # Map from event ID to chain ID/sequence number. + chain_map = {} # type: Dict[str, Tuple[int, int]] + + # Set of event IDs to calculate chain ID/seq numbers for. + events_to_calc_chain_id_for = set(event_to_room_id) + + # We check if there are any events that need to be handled in the rooms + # we're looking at. These should just be out of band memberships, where + # we didn't have the auth chain when we first persisted. + rows = db_pool.simple_select_many_txn( txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - backfilled=backfilled, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="room_id", + iterable=set(event_to_room_id.values()), + retcols=("event_id", "type", "state_key"), ) + for row in rows: + event_id = row["event_id"] + event_type = row["type"] + state_key = row["state_key"] + + # (We could pull out the auth events for all rows at once using + # simple_select_many, but this case happens rarely and almost always + # with a single row.) + auth_events = db_pool.simple_select_onecol_txn( + txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id", + ) - # We call this last as it assumes we've inserted the events into - # room_memberships, where applicable. - self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + events_to_calc_chain_id_for.add(event_id) + event_to_types[event_id] = (event_type, state_key) + event_to_auth_chain[event_id] = auth_events + + # First we get the chain ID and sequence numbers for the events' + # auth events (that aren't also currently being persisted). + # + # Note that there there is an edge case here where we might not have + # calculated chains and sequence numbers for events that were "out + # of band". We handle this case by fetching the necessary info and + # adding it to the set of events to calculate chain IDs for. + + missing_auth_chains = { + a_id + for auth_events in event_to_auth_chain.values() + for a_id in auth_events + if a_id not in events_to_calc_chain_id_for + } + + # We loop here in case we find an out of band membership and need to + # fetch their auth event info. + while missing_auth_chains: + sql = """ + SELECT event_id, events.type, state_key, chain_id, sequence_number + FROM events + INNER JOIN state_events USING (event_id) + LEFT JOIN event_auth_chains USING (event_id) + WHERE + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", missing_auth_chains, + ) + txn.execute(sql + clause, args) + + missing_auth_chains.clear() + + for auth_id, event_type, state_key, chain_id, sequence_number in txn: + event_to_types[auth_id] = (event_type, state_key) + + if chain_id is None: + # No chain ID, so the event was persisted out of band. + # We add to list of events to calculate auth chains for. + + events_to_calc_chain_id_for.add(auth_id) + + event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn( + txn, + "event_auth", + keyvalues={"event_id": auth_id}, + retcol="auth_id", + ) + + missing_auth_chains.update( + e + for e in event_to_auth_chain[auth_id] + if e not in event_to_types + ) + else: + chain_map[auth_id] = (chain_id, sequence_number) + + # Now we check if we have any events where we don't have auth chain, + # this should only be out of band memberships. + for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain): + for auth_id in event_to_auth_chain[event_id]: + if ( + auth_id not in chain_map + and auth_id not in events_to_calc_chain_id_for + ): + events_to_calc_chain_id_for.discard(event_id) + + # If this is an event we're trying to persist we add it to + # the list of events to calculate chain IDs for next time + # around. (Otherwise we will have already added it to the + # table). + room_id = event_to_room_id.get(event_id) + if room_id: + e_type, state_key = event_to_types[event_id] + db_pool.simple_insert_txn( + txn, + table="event_auth_chain_to_calculate", + values={ + "event_id": event_id, + "room_id": room_id, + "type": e_type, + "state_key": state_key, + }, + ) + + # We stop checking the event's auth events since we've + # discarded it. + break + + if not events_to_calc_chain_id_for: + return + + # Allocate chain ID/sequence numbers to each new event. + new_chain_tuples = cls._allocate_chain_ids( + txn, + db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, + events_to_calc_chain_id_for, + chain_map, + ) + chain_map.update(new_chain_tuples) + + db_pool.simple_insert_many_txn( + txn, + table="event_auth_chains", + values=[ + {"event_id": event_id, "chain_id": c_id, "sequence_number": seq} + for event_id, (c_id, seq) in new_chain_tuples.items() + ], + ) + + db_pool.simple_delete_many_txn( + txn, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="event_id", + iterable=new_chain_tuples, + ) + + # Now we need to calculate any new links between chains caused by + # the new events. + # + # Links are pairs of chain ID/sequence numbers such that for any + # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain + # if and only if there is at least one link (CA, S1) -> (CB, S2) + # where SA >= S1 and S2 >= SB. + # + # We try and avoid adding redundant links to the table, e.g. if we + # have two links between two chains which both start/end at the + # sequence number event (or cross) then one can be safely dropped. + # + # To calculate new links we look at every new event and: + # 1. Fetch the chain ID/sequence numbers of its auth events, + # discarding any that are reachable by other auth events, or + # that have the same chain ID as the event. + # 2. For each retained auth event we: + # a. Add a link from the event's to the auth event's chain + # ID/sequence number; and + # b. Add a link from the event to every chain reachable by the + # auth event. + + # Step 1, fetch all existing links from all the chains we've seen + # referenced. + chain_links = _LinkMap() + rows = db_pool.simple_select_many_txn( + txn, + table="event_auth_chain_links", + column="origin_chain_id", + iterable={chain_id for chain_id, _ in chain_map.values()}, + keyvalues={}, + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), + ) + for row in rows: + chain_links.add_link( + (row["origin_chain_id"], row["origin_sequence_number"]), + (row["target_chain_id"], row["target_sequence_number"]), + new=False, + ) + + # We do this in toplogical order to avoid adding redundant links. + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + chain_id, sequence_number = chain_map[event_id] + + # Filter out auth events that are reachable by other auth + # events. We do this by looking at every permutation of pairs of + # auth events (A, B) to check if B is reachable from A. + reduction = { + a_id + for a_id in event_to_auth_chain.get(event_id, []) + if chain_map[a_id][0] != chain_id + } + for start_auth_id, end_auth_id in itertools.permutations( + event_to_auth_chain.get(event_id, []), r=2, + ): + if chain_links.exists_path_from( + chain_map[start_auth_id], chain_map[end_auth_id] + ): + reduction.discard(end_auth_id) + + # Step 2, figure out what the new links are from the reduced + # list of auth events. + for auth_id in reduction: + auth_chain_id, auth_sequence_number = chain_map[auth_id] + + # Step 2a, add link between the event and auth event + chain_links.add_link( + (chain_id, sequence_number), (auth_chain_id, auth_sequence_number) + ) + + # Step 2b, add a link to chains reachable from the auth + # event. + for target_id, target_seq in chain_links.get_links_from( + (auth_chain_id, auth_sequence_number) + ): + if target_id == chain_id: + continue + + chain_links.add_link( + (chain_id, sequence_number), (target_id, target_seq) + ) + + db_pool.simple_insert_many_txn( + txn, + table="event_auth_chain_links", + values=[ + { + "origin_chain_id": source_id, + "origin_sequence_number": source_seq, + "target_chain_id": target_id, + "target_sequence_number": target_seq, + } + for ( + source_id, + source_seq, + target_id, + target_seq, + ) in chain_links.get_additions() + ], + ) + + @staticmethod + def _allocate_chain_ids( + txn, + db_pool: DatabasePool, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, List[str]], + events_to_calc_chain_id_for: Set[str], + chain_map: Dict[str, Tuple[int, int]], + ) -> Dict[str, Tuple[int, int]]: + """Allocates, but does not persist, chain ID/sequence numbers for the + events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index + for info on args) + """ + + # We now calculate the chain IDs/sequence numbers for the events. We do + # this by looking at the chain ID and sequence number of any auth event + # with the same type/state_key and incrementing the sequence number by + # one. If there was no match or the chain ID/sequence number is already + # taken we generate a new chain. + # + # We try to reduce the number of times that we hit the database by + # batching up calls, to make this more efficient when persisting large + # numbers of state events (e.g. during joins). + # + # We do this by: + # 1. Calculating for each event which auth event will be used to + # inherit the chain ID, i.e. converting the auth chain graph to a + # tree that we can allocate chains on. We also keep track of which + # existing chain IDs have been referenced. + # 2. Fetching the max allocated sequence number for each referenced + # existing chain ID, generating a map from chain ID to the max + # allocated sequence number. + # 3. Iterating over the tree and allocating a chain ID/seq no. to the + # new event, by incrementing the sequence number from the + # referenced event's chain ID/seq no. and checking that the + # incremented sequence number hasn't already been allocated (by + # looking in the map generated in the previous step). We generate a + # new chain if the sequence number has already been allocated. + # + + existing_chains = set() # type: Set[int] + tree = [] # type: List[Tuple[str, Optional[str]]] + + # We need to do this in a topologically sorted order as we want to + # generate chain IDs/sequence numbers of an event's auth events before + # the event itself. + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + for auth_id in event_to_auth_chain.get(event_id, []): + if event_to_types.get(event_id) == event_to_types.get(auth_id): + existing_chain_id = chain_map.get(auth_id) + if existing_chain_id: + existing_chains.add(existing_chain_id[0]) + + tree.append((event_id, auth_id)) + break + else: + tree.append((event_id, None)) + + # Fetch the current max sequence number for each existing referenced chain. + sql = """ + SELECT chain_id, MAX(sequence_number) FROM event_auth_chains + WHERE %s + GROUP BY chain_id + """ + clause, args = make_in_list_sql_clause( + db_pool.engine, "chain_id", existing_chains + ) + txn.execute(sql % (clause,), args) + + chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int] + + # Allocate the new events chain ID/sequence numbers. + # + # To reduce the number of calls to the database we don't allocate a + # chain ID number in the loop, instead we use a temporary `object()` for + # each new chain ID. Once we've done the loop we generate the necessary + # number of new chain IDs in one call, replacing all temporary + # objects with real allocated chain IDs. + + unallocated_chain_ids = set() # type: Set[object] + new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]] + for event_id, auth_event_id in tree: + # If we reference an auth_event_id we fetch the allocated chain ID, + # either from the existing `chain_map` or the newly generated + # `new_chain_tuples` map. + existing_chain_id = None + if auth_event_id: + existing_chain_id = new_chain_tuples.get(auth_event_id) + if not existing_chain_id: + existing_chain_id = chain_map[auth_event_id] + + new_chain_tuple = None # type: Optional[Tuple[Any, int]] + if existing_chain_id: + # We found a chain ID/sequence number candidate, check its + # not already taken. + proposed_new_id = existing_chain_id[0] + proposed_new_seq = existing_chain_id[1] + 1 + + if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq: + new_chain_tuple = ( + proposed_new_id, + proposed_new_seq, + ) + + # If we need to start a new chain we allocate a temporary chain ID. + if not new_chain_tuple: + new_chain_tuple = (object(), 1) + unallocated_chain_ids.add(new_chain_tuple[0]) + + new_chain_tuples[event_id] = new_chain_tuple + chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1] + + # Generate new chain IDs for all unallocated chain IDs. + newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn( + txn, len(unallocated_chain_ids) + ) + + # Map from potentially temporary chain ID to real chain ID + chain_id_to_allocated_map = dict( + zip(unallocated_chain_ids, newly_allocated_chain_ids) + ) # type: Dict[Any, int] + chain_id_to_allocated_map.update((c, c) for c in existing_chains) + + return { + event_id: (chain_id_to_allocated_map[chain_id], seq) + for event_id, (chain_id, seq) in new_chain_tuples.items() + } def _persist_transaction_ids_txn( self, @@ -489,7 +965,7 @@ class PersistEventsStore: WHERE room_id = ? AND type = ? AND state_key = ? ) """ - txn.executemany( + txn.execute_batch( sql, ( ( @@ -508,7 +984,7 @@ class PersistEventsStore: ) # Now we actually update the current_state_events table - txn.executemany( + txn.execute_batch( "DELETE FROM current_state_events" " WHERE room_id = ? AND type = ? AND state_key = ?", ( @@ -520,7 +996,7 @@ class PersistEventsStore: # We include the membership in the current state table, hence we do # a lookup when we insert. This assumes that all events have already # been inserted into room_memberships. - txn.executemany( + txn.execute_batch( """INSERT INTO current_state_events (room_id, type, state_key, event_id, membership) VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) @@ -540,7 +1016,7 @@ class PersistEventsStore: # we have no record of the fact the user *was* a member of the # room but got, say, state reset out of it. if to_delete or to_insert: - txn.executemany( + txn.execute_batch( "DELETE FROM local_current_membership" " WHERE room_id = ? AND user_id = ?", ( @@ -551,7 +1027,7 @@ class PersistEventsStore: ) if to_insert: - txn.executemany( + txn.execute_batch( """INSERT INTO local_current_membership (room_id, user_id, event_id, membership) VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) @@ -799,7 +1275,8 @@ class PersistEventsStore: return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _store_event_txn(self, txn, events_and_contexts): - """Insert new events into the event and event_json tables + """Insert new events into the event, event_json, redaction and + state_events tables. Args: txn (twisted.enterprise.adbapi.Connection): db connection @@ -871,6 +1348,29 @@ class PersistEventsStore: updatevalues={"have_censored": False}, ) + state_events_and_contexts = [ + ec for ec in events_and_contexts if ec[0].is_state() + ] + + state_values = [] + for event, context in state_events_and_contexts: + vals = { + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + } + + # TODO: How does this work with backfilling? + if hasattr(event, "replaces_state"): + vals["prev_state"] = event.replaces_state + + state_values.append(vals) + + self.db_pool.simple_insert_many_txn( + txn, table="state_events", values=state_values + ) + def _store_rejected_events_txn(self, txn, events_and_contexts): """Add rows to the 'rejections' table for received events which were rejected @@ -987,29 +1487,6 @@ class PersistEventsStore: txn, [event for event, _ in events_and_contexts] ) - state_events_and_contexts = [ - ec for ec in events_and_contexts if ec[0].is_state() - ] - - state_values = [] - for event, context in state_events_and_contexts: - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } - - # TODO: How does this work with backfilling? - if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - - state_values.append(vals) - - self.db_pool.simple_insert_many_txn( - txn, table="state_events", values=state_values - ) - # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1350,7 +1827,7 @@ class PersistEventsStore: """ if events_and_contexts: - txn.executemany( + txn.execute_batch( sql, ( ( @@ -1379,7 +1856,7 @@ class PersistEventsStore: # Now we delete the staging area for *all* events that were being # persisted. - txn.executemany( + txn.execute_batch( "DELETE FROM event_push_actions_staging WHERE event_id = ?", ((event.event_id,) for event, _ in all_events_and_contexts), ) @@ -1498,7 +1975,7 @@ class PersistEventsStore: " )" ) - txn.executemany( + txn.execute_batch( query, [ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) @@ -1512,7 +1989,7 @@ class PersistEventsStore: "DELETE FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" ) - txn.executemany( + txn.execute_batch( query, [ (ev.event_id, ev.room_id) @@ -1520,3 +1997,131 @@ class PersistEventsStore: if not ev.internal_metadata.is_outlier() ], ) + + +@attr.s(slots=True) +class _LinkMap: + """A helper type for tracking links between chains. + """ + + # Stores the set of links as nested maps: source chain ID -> target chain ID + # -> source sequence number -> target sequence number. + maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict) + + # Stores the links that have been added (with new set to true), as tuples of + # `(source chain ID, source sequence no, target chain ID, target sequence no.)` + additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set) + + def add_link( + self, + src_tuple: Tuple[int, int], + target_tuple: Tuple[int, int], + new: bool = True, + ) -> bool: + """Add a new link between two chains, ensuring no redundant links are added. + + New links should be added in topological order. + + Args: + src_tuple: The chain ID/sequence number of the source of the link. + target_tuple: The chain ID/sequence number of the target of the link. + new: Whether this is a "new" link, i.e. should it be returned + by `get_additions`. + + Returns: + True if a link was added, false if the given link was dropped as redundant + """ + src_chain, src_seq = src_tuple + target_chain, target_seq = target_tuple + + current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {}) + + assert src_chain != target_chain + + if new: + # Check if the new link is redundant + for current_seq_src, current_seq_target in current_links.items(): + # If a link "crosses" another link then its redundant. For example + # in the following link 1 (L1) is redundant, as any event reachable + # via L1 is *also* reachable via L2. + # + # Chain A Chain B + # | | + # L1 |------ | + # | | | + # L2 |---- | -->| + # | | | + # | |--->| + # | | + # | | + # + # So we only need to keep links which *do not* cross, i.e. links + # that both start and end above or below an existing link. + # + # Note, since we add links in topological ordering we should never + # see `src_seq` less than `current_seq_src`. + + if current_seq_src <= src_seq and target_seq <= current_seq_target: + # This new link is redundant, nothing to do. + return False + + self.additions.add((src_chain, src_seq, target_chain, target_seq)) + + current_links[src_seq] = target_seq + return True + + def get_links_from( + self, src_tuple: Tuple[int, int] + ) -> Generator[Tuple[int, int], None, None]: + """Gets the chains reachable from the given chain/sequence number. + + Yields: + The chain ID and sequence number the link points to. + """ + src_chain, src_seq = src_tuple + for target_id, sequence_numbers in self.maps.get(src_chain, {}).items(): + for link_src_seq, target_seq in sequence_numbers.items(): + if link_src_seq <= src_seq: + yield target_id, target_seq + + def get_links_between( + self, source_chain: int, target_chain: int + ) -> Generator[Tuple[int, int], None, None]: + """Gets the links between two chains. + + Yields: + The source and target sequence numbers. + """ + + yield from self.maps.get(source_chain, {}).get(target_chain, {}).items() + + def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]: + """Gets any newly added links. + + Yields: + The source chain ID/sequence number and target chain ID/sequence number + """ + + for src_chain, src_seq, target_chain, _ in self.additions: + target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq) + if target_seq is not None: + yield (src_chain, src_seq, target_chain, target_seq) + + def exists_path_from( + self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int], + ) -> bool: + """Checks if there is a path between the source chain ID/sequence and + target chain ID/sequence. + """ + src_chain, src_seq = src_tuple + target_chain, target_seq = target_tuple + + if src_chain == target_chain: + return target_seq <= src_seq + + links = self.get_links_between(src_chain, target_chain) + for link_start_seq, link_end_seq in links: + if link_start_seq <= src_seq and target_seq <= link_end_seq: + return True + + return False diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 97b6754846..5ca4fa6817 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -14,14 +14,41 @@ # limitations under the License. import logging +from typing import Dict, List, Optional, Tuple + +import attr from synapse.api.constants import EventContentFields +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, make_tuple_comparison_clause +from synapse.storage.databases.main.events import PersistEventsStore +from synapse.storage.types import Cursor +from synapse.types import JsonDict logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True) +class _CalculateChainCover: + """Return value for _calculate_chain_cover_txn. + """ + + # The last room_id/depth/stream processed. + room_id = attr.ib(type=str) + depth = attr.ib(type=int) + stream = attr.ib(type=int) + + # Number of rows processed + processed_count = attr.ib(type=int) + + # Map from room_id to last depth/stream processed for each room that we have + # processed all events for (i.e. the rooms we can flip the + # `has_auth_chain_index` for) + finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]]) + + class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" @@ -99,13 +126,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): columns=["user_id", "created_ts"], ) + self.db_pool.updates.register_background_update_handler( + "rejected_events_metadata", self._rejected_events_metadata, + ) + + self.db_pool.updates.register_background_update_handler( + "chain_cover", self._chain_cover_index, + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 - def reindex_txn(txn): sql = ( "SELECT stream_ordering, event_id, json FROM events" @@ -143,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" - for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + txn.execute_batch(sql, update_rows) progress = { "target_min_stream_id_inclusive": target_min_stream_id, @@ -175,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 - def reindex_search_txn(txn): sql = ( "SELECT stream_ordering, event_id FROM events" @@ -221,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + txn.execute_batch(sql, rows_to_update) progress = { "target_min_stream_id_inclusive": target_min_stream_id, @@ -582,3 +609,314 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): await self.db_pool.updates._end_background_update("event_store_labels") return num_rows + + async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int: + """Adds rejected events to the `state_events` and `event_auth` metadata + tables. + """ + + last_event_id = progress.get("last_event_id", "") + + def get_rejected_events( + txn: Cursor, + ) -> List[Tuple[str, str, JsonDict, bool, bool]]: + # Fetch rejected event json, their room version and whether we have + # inserted them into the state_events or auth_events tables. + # + # Note we can assume that events that don't have a corresponding + # room version are V1 rooms. + sql = """ + SELECT DISTINCT + event_id, + COALESCE(room_version, '1'), + json, + state_events.event_id IS NOT NULL, + event_auth.event_id IS NOT NULL + FROM rejections + INNER JOIN event_json USING (event_id) + LEFT JOIN rooms USING (room_id) + LEFT JOIN state_events USING (event_id) + LEFT JOIN event_auth USING (event_id) + WHERE event_id > ? + ORDER BY event_id + LIMIT ? + """ + + txn.execute(sql, (last_event_id, batch_size,)) + + return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore + + results = await self.db_pool.runInteraction( + desc="_rejected_events_metadata_get", func=get_rejected_events + ) + + if not results: + await self.db_pool.updates._end_background_update( + "rejected_events_metadata" + ) + return 0 + + state_events = [] + auth_events = [] + for event_id, room_version, event_json, has_state, has_event_auth in results: + last_event_id = event_id + + if has_state and has_event_auth: + continue + + room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version) + if not room_version_obj: + # We no longer support this room version, so we just ignore the + # events entirely. + logger.info( + "Ignoring event with unknown room version %r: %r", + room_version, + event_id, + ) + continue + + event = make_event_from_dict(event_json, room_version_obj) + + if not event.is_state(): + continue + + if not has_state: + state_events.append( + { + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + } + ) + + if not has_event_auth: + for auth_id in event.auth_event_ids(): + auth_events.append( + { + "room_id": event.room_id, + "event_id": event.event_id, + "auth_id": auth_id, + } + ) + + if state_events: + await self.db_pool.simple_insert_many( + table="state_events", + values=state_events, + desc="_rejected_events_metadata_state_events", + ) + + if auth_events: + await self.db_pool.simple_insert_many( + table="event_auth", + values=auth_events, + desc="_rejected_events_metadata_event_auth", + ) + + await self.db_pool.updates._background_update_progress( + "rejected_events_metadata", {"last_event_id": last_event_id} + ) + + if len(results) < batch_size: + await self.db_pool.updates._end_background_update( + "rejected_events_metadata" + ) + + return len(results) + + async def _chain_cover_index(self, progress: dict, batch_size: int) -> int: + """A background updates that iterates over all rooms and generates the + chain cover index for them. + """ + + current_room_id = progress.get("current_room_id", "") + + # Where we've processed up to in the room, defaults to the start of the + # room. + last_depth = progress.get("last_depth", -1) + last_stream = progress.get("last_stream", -1) + + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + current_room_id, + last_depth, + last_stream, + batch_size, + single_room=False, + ) + + finished = result.processed_count == 0 + + total_rows_processed = result.processed_count + current_room_id = result.room_id + last_depth = result.depth + last_stream = result.stream + + for room_id, (depth, stream) in result.finished_room_map.items(): + # If we've done all the events in the room we flip the + # `has_auth_chain_index` in the DB. Note that its possible for + # further events to be persisted between the above and setting the + # flag without having the chain cover calculated for them. This is + # fine as a) the code gracefully handles these cases and b) we'll + # calculate them below. + + await self.db_pool.simple_update( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": True}, + desc="_chain_cover_index", + ) + + # Handle any events that might have raced with us flipping the + # bit above. + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + room_id, + depth, + stream, + batch_size=None, + single_room=True, + ) + + total_rows_processed += result.processed_count + + if finished: + await self.db_pool.updates._end_background_update("chain_cover") + return total_rows_processed + + await self.db_pool.updates._background_update_progress( + "chain_cover", + { + "current_room_id": current_room_id, + "last_depth": last_depth, + "last_stream": last_stream, + }, + ) + + return total_rows_processed + + def _calculate_chain_cover_txn( + self, + txn: Cursor, + last_room_id: str, + last_depth: int, + last_stream: int, + batch_size: Optional[int], + single_room: bool, + ) -> _CalculateChainCover: + """Calculate the chain cover for `batch_size` events, ordered by + `(room_id, depth, stream)`. + + Args: + txn, + last_room_id, last_depth, last_stream: The `(room_id, depth, stream)` + tuple to fetch results after. + batch_size: The maximum number of events to process. If None then + no limit. + single_room: Whether to calculate the index for just the given + room. + """ + + # Get the next set of events in the room (that we haven't already + # computed chain cover for). We do this in topological order. + + # We want to do a `(topological_ordering, stream_ordering) > (?,?)` + # comparison, but that is not supported on older SQLite versions + tuple_clause, tuple_args = make_tuple_comparison_clause( + self.database_engine, + [ + ("events.room_id", last_room_id), + ("topological_ordering", last_depth), + ("stream_ordering", last_stream), + ], + ) + + extra_clause = "" + if single_room: + extra_clause = "AND events.room_id = ?" + tuple_args.append(last_room_id) + + sql = """ + SELECT + event_id, state_events.type, state_events.state_key, + topological_ordering, stream_ordering, + events.room_id + FROM events + INNER JOIN state_events USING (event_id) + LEFT JOIN event_auth_chains USING (event_id) + LEFT JOIN event_auth_chain_to_calculate USING (event_id) + WHERE event_auth_chains.event_id IS NULL + AND event_auth_chain_to_calculate.event_id IS NULL + AND %(tuple_cmp)s + %(extra)s + ORDER BY events.room_id, topological_ordering, stream_ordering + %(limit)s + """ % { + "tuple_cmp": tuple_clause, + "limit": "LIMIT ?" if batch_size is not None else "", + "extra": extra_clause, + } + + if batch_size is not None: + tuple_args.append(batch_size) + + txn.execute(sql, tuple_args) + rows = txn.fetchall() + + # Put the results in the necessary format for + # `_add_chain_cover_index` + event_to_room_id = {row[0]: row[5] for row in rows} + event_to_types = {row[0]: (row[1], row[2]) for row in rows} + + # Calculate the new last position we've processed up to. + new_last_depth = rows[-1][3] if rows else last_depth # type: int + new_last_stream = rows[-1][4] if rows else last_stream # type: int + new_last_room_id = rows[-1][5] if rows else "" # type: str + + # Map from room_id to last depth/stream_ordering processed for the room, + # excluding the last room (which we're likely still processing). We also + # need to include the room passed in if it's not included in the result + # set (as we then know we've processed all events in said room). + # + # This is the set of rooms that we can now safely flip the + # `has_auth_chain_index` bit for. + finished_rooms = { + row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id + } + if last_room_id not in finished_rooms and last_room_id != new_last_room_id: + finished_rooms[last_room_id] = (last_depth, last_stream) + + count = len(rows) + + # We also need to fetch the auth events for them. + auth_events = self.db_pool.simple_select_many_txn( + txn, + table="event_auth", + column="event_id", + iterable=event_to_room_id, + keyvalues={}, + retcols=("event_id", "auth_id"), + ) + + event_to_auth_chain = {} # type: Dict[str, List[str]] + for row in auth_events: + event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) + + # Calculate and persist the chain cover index for this set of events. + # + # Annoyingly we need to gut wrench into the persit event store so that + # we can reuse the function to calculate the chain cover for rooms. + PersistEventsStore._add_chain_cover_index( + txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + ) + + return _CalculateChainCover( + room_id=new_last_room_id, + depth=new_last_depth, + stream=new_last_stream, + processed_count=count, + finished_room_map=finished_rooms, + ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4732685f6e..71d823be72 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, stream_name="events", instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", + tables=[("events", "instance_name", "stream_ordering")], sequence_name="events_stream_seq", writers=hs.config.worker.writers.events, ) @@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, stream_name="backfill", instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", + tables=[("events", "instance_name", "stream_ordering")], sequence_name="events_backfill_stream_seq", positive=False, writers=hs.config.worker.writers.events, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 4b2f224718..e017177655 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_local_media_before( self, before_ts: int, size_gt: int, keep_profiles: bool, - ) -> Optional[List[str]]: + ) -> List[str]: # to find files that have never been accessed (last_access_ts IS NULL) # compare with `created_ts` @@ -416,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_origin = ? AND media_id = ?" ) - txn.executemany( + txn.execute_batch( sql, ( (time_ms, media_origin, media_id) @@ -429,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_id = ?" ) - txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) + txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media)) return await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn @@ -556,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" def _delete_url_cache_txn(txn): - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) return await self.db_pool.runInteraction( "delete_url_cache", _delete_url_cache_txn @@ -585,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_media_txn(txn): sql = "DELETE FROM local_media_repository WHERE media_id = ?" - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) return await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 0e25ca3d7a..54ef0f1f54 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_avatar_url( - self, user_localpart: str, new_avatar_url: str + self, user_localpart: str, new_avatar_url: Optional[str] ) -> None: await self.db_pool.simple_update_one( table="profiles", diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 5d668aadb2..ecfc9f20b1 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): ) # Update backward extremeties - txn.executemany( + txn.execute_batch( "INSERT INTO event_backward_extremities (room_id, event_id)" " VALUES (?, ?)", [(room_id, event_id) for event_id, in new_backwards_extrems], diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 77ba9d819e..bc7621b8d6 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -17,14 +17,13 @@ import logging from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple -from canonicaljson import encode_canonical_json - from synapse.push import PusherConfig, ThrottleParams from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.types import Connection from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore): "device_display_name": device_display_name, "ts": pushkey_ts, "lang": lang, - "data": bytearray(encode_canonical_json(data)), + "data": json_encoder.encode(data), "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 1e7949a323..e4843a202c 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,15 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import Any, Dict, List, Optional, Tuple from twisted.internet import defer +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. -class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): - """This is an abstract base class where subclasses must implement - `get_max_receipt_stream_id` which can be called in the initializer. - """ - +class ReceiptsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): + self._instance_name = hs.get_instance_name() + + if isinstance(database.engine, PostgresEngine): + self._can_write_to_receipts = ( + self._instance_name in hs.config.worker.writers.receipts + ) + + self._receipts_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="receipts", + instance_name=self._instance_name, + tables=[("receipts_linearized", "instance_name", "stream_id")], + sequence_name="receipts_sequence", + writers=hs.config.worker.writers.receipts, + ) + else: + self._can_write_to_receipts = True + + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.receipts: + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + else: + self._receipts_id_gen = SlavedIdTracker( + db_conn, "receipts_linearized", "stream_id" + ) + super().__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - @abc.abstractmethod def get_max_receipt_stream_id(self): """Get the current max stream ID for receipts stream Returns: int """ - raise NotImplementedError() + return self._receipts_id_gen.get_current_token() @cached() async def get_users_with_read_receipts_in_room(self, room_id): @@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - -class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" + def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + self.get_receipts_for_user.invalidate((user_id, receipt_type)) + self._get_linearized_receipts_for_room.invalidate_many((room_id,)) + self.get_last_receipt_event_id_for_user.invalidate( + (user_id, room_id, receipt_type) ) + self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) + self.get_receipts_for_room.invalidate((room_id, receipt_type)) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + for row in rows: + self.invalidate_caches_for_receipt( + row.room_id, row.receipt_type, row.user_id + ) + self._receipts_stream_cache.entity_has_changed(row.room_id, token) - super().__init__(database, db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() + return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( self, txn, room_id, receipt_type, user_id, event_id, data, stream_id @@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ + assert self._can_write_to_receipts + res = self.db_pool.simple_select_one_txn( txn, table="events", @@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore): ) return None - txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) - txn.call_after( - self._invalidate_get_users_with_receipts_in_room, - room_id, - receipt_type, - user_id, - ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) - # FIXME: This shouldn't invalidate the whole cache txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) + self.invalidate_caches_for_receipt, room_id, receipt_type, user_id ) txn.call_after( self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) - txn.call_after( - self.get_last_receipt_event_id_for_user.invalidate, - (user_id, room_id, receipt_type), - ) - self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", @@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore): Automatically does conversion between linearized and graph representations. """ + assert self._can_write_to_receipts + if not event_ids: return None @@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore): async def insert_graph_receipt( self, room_id, receipt_type, user_id, event_ids, data ): + assert self._can_write_to_receipts + return await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, @@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore): def insert_graph_receipt_txn( self, txn, room_id, receipt_type, user_id, event_ids, data ): + assert self._can_write_to_receipts + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after( self._invalidate_get_users_with_receipts_in_room, @@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "data": json_encoder.encode(data), }, ) + + +class ReceiptsStore(ReceiptsWorkerStore): + pass diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 8d05288ed4..585b4049d6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1104,7 +1104,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): FROM user_threepids """ - txn.executemany(sql, [(id_server,) for id_server in id_servers]) + txn.execute_batch(sql, [(id_server,) for id_server in id_servers]) if id_servers: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4650d0689b..a9fcb5f59c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -16,7 +16,6 @@ import collections import logging -import re from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple @@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +from synapse.util.stringutils import MXC_REGEX logger = logging.getLogger(__name__) @@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore): return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), + retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), desc="get_room", allow_none=True, ) @@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - sql = """ SELECT stream_ordering, json FROM events JOIN event_json USING (room_id, event_id) @@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore): for url in (content_url, thumbnail_url): if not url: continue - matches = mxc_re.match(url) + matches = MXC_REGEX.match(url) if matches: hostname = matches.group(1) media_id = matches.group(2) @@ -1166,6 +1164,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore): # It's overridden by RoomStore for the synapse master. raise NotImplementedError() + async def has_auth_chain_index(self, room_id: str) -> bool: + """Check if the room has (or can have) a chain cover index. + + Defaults to True if we don't have an entry in `rooms` table nor any + events for the room. + """ + + has_auth_chain_index = await self.db_pool.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="has_auth_chain_index", + desc="has_auth_chain_index", + allow_none=True, + ) + + if has_auth_chain_index: + return True + + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + max_ordering = await self.db_pool.simple_select_one_onecol( + table="events", + keyvalues={"room_id": room_id}, + retcol="MAX(stream_ordering)", + allow_none=True, + desc="upsert_room_on_join", + ) + + return max_ordering is None + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -1179,12 +1208,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): Called when we join a room over federation, and overwrites any room version currently in the table. """ + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + has_auth_chain_index = await self.has_auth_chain_index(room_id) + await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", keyvalues={"room_id": room_id}, values={"room_version": room_version.identifier}, - insertion_values={"is_public": False, "creator": ""}, + insertion_values={ + "is_public": False, + "creator": "", + "has_auth_chain_index": has_auth_chain_index, + }, # rooms has a unique constraint on room_id, so no need to lock when doing an # emulated upsert. lock=False, @@ -1219,6 +1257,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "creator": room_creator_user_id, "is_public": is_public, "room_version": room_version.identifier, + "has_auth_chain_index": True, }, ) if is_public: @@ -1247,6 +1286,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): When we receive an invite or any other event over federation that may relate to a room we are not in, store the version of the room if we don't already know the room version. """ + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + has_auth_chain_index = await self.has_auth_chain_index(room_id) + await self.db_pool.simple_upsert( desc="maybe_store_room_on_outlier_membership", table="rooms", @@ -1256,6 +1300,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "room_version": room_version.identifier, "is_public": False, "creator": "", + "has_auth_chain_index": has_auth_chain_index, }, # rooms has a unique constraint on room_id, so no need to lock when doing an # emulated upsert. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index dcdaf09682..92382bed28 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): "max_stream_id_exclusive", self._stream_order_on_start + 1 ) - INSERT_CLUMP_SIZE = 1000 - def add_membership_profile_txn(txn): sql = """ SELECT stream_ordering, event_id, events.room_id, event_json.json @@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): UPDATE room_memberships SET display_name = ?, avatar_url = ? WHERE event_id = ? AND room_id = ? """ - for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(to_update_sql, clump) + txn.execute_batch(to_update_sql, to_update) progress = { "target_min_stream_id_inclusive": target_min_stream_id, diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres new file mode 100644 index 0000000000..de57645019 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE access_tokens DROP COLUMN last_used; \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite new file mode 100644 index 0000000000..ee0e3521bf --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite @@ -0,0 +1,62 @@ +/* + * Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + -- Dropping last_used column from access_tokens table. + +CREATE TABLE access_tokens2 ( + id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + valid_until_ms BIGINT, + puppets_user_id TEXT, + last_validated BIGINT, + UNIQUE(token) +); + +INSERT INTO access_tokens2(id, user_id, device_id, token) + SELECT id, user_id, device_id, token FROM access_tokens; + +DROP TABLE access_tokens; +ALTER TABLE access_tokens2 RENAME TO access_tokens; + +CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id); + + +-- Re-adding foreign key reference in event_txn_id table + +CREATE TABLE event_txn_id2 ( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + token_id BIGINT NOT NULL, + txn_id TEXT NOT NULL, + inserted_ts BIGINT NOT NULL, + FOREIGN KEY (event_id) + REFERENCES events (event_id) ON DELETE CASCADE, + FOREIGN KEY (token_id) + REFERENCES access_tokens (id) ON DELETE CASCADE +); + +INSERT INTO event_txn_id2(event_id, room_id, user_id, token_id, txn_id, inserted_ts) + SELECT event_id, room_id, user_id, token_id, txn_id, inserted_ts FROM event_txn_id; + +DROP TABLE event_txn_id; +ALTER TABLE event_txn_id2 RENAME TO event_txn_id; + +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id); +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id); +CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts); \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql new file mode 100644 index 0000000000..9c95646281 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5828, 'rejected_events_metadata', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py index f35c70b699..9e8f35c1d2 100644 --- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py +++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py @@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs # { "ignored_users": "@someone:example.org": {} } ignored_users = content.get("ignored_users", {}) if isinstance(ignored_users, dict) and ignored_users: - cur.executemany(insert_sql, [(user_id, u) for u in ignored_users]) + cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users]) # Add indexes after inserting data for efficiency. logger.info("Adding constraints to ignored_users table") diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql new file mode 100644 index 0000000000..729196cfd5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql @@ -0,0 +1,52 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- See docs/auth_chain_difference_algorithm.md + +CREATE TABLE event_auth_chains ( + event_id TEXT PRIMARY KEY, + chain_id BIGINT NOT NULL, + sequence_number BIGINT NOT NULL +); + +CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number); + + +CREATE TABLE event_auth_chain_links ( + origin_chain_id BIGINT NOT NULL, + origin_sequence_number BIGINT NOT NULL, + + target_chain_id BIGINT NOT NULL, + target_sequence_number BIGINT NOT NULL +); + + +CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id); + + +-- Events that we have persisted but not calculated auth chains for, +-- e.g. out of band memberships (where we don't have the auth chain) +CREATE TABLE event_auth_chain_to_calculate ( + event_id TEXT PRIMARY KEY, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL +); + +CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id); + + +-- Whether we've calculated the above index for a room. +ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN; diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres new file mode 100644 index 0000000000..e8a035bbeb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id; diff --git a/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql new file mode 100644 index 0000000000..64ab696cfe --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql @@ -0,0 +1,17 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is no longer used and was only kept until we bumped the schema version. +DROP TABLE IF EXISTS account_data_max_stream_id; diff --git a/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql new file mode 100644 index 0000000000..fb71b360a0 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql @@ -0,0 +1,17 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is no longer used and was only kept until we bumped the schema version. +DROP TABLE IF EXISTS cache_invalidation_stream; diff --git a/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql new file mode 100644 index 0000000000..fe3dca71dd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5906, 'chain_cover', '{}', 'rejected_events_metadata'); diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql new file mode 100644 index 0000000000..46abf8d562 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_account_data ADD COLUMN instance_name TEXT; +ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT; +ALTER TABLE account_data ADD COLUMN instance_name TEXT; + +ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres new file mode 100644 index 0000000000..4a6e6c74f5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres @@ -0,0 +1,32 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS account_data_sequence; + +-- We need to take the max across all the account_data tables as they share the +-- ID generator +SELECT setval('account_data_sequence', ( + SELECT GREATEST( + (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data), + (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions), + (SELECT COALESCE(MAX(stream_id), 1) FROM account_data) + ) +)); + +CREATE SEQUENCE IF NOT EXISTS receipts_sequence; + +SELECT setval('receipts_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized +)); diff --git a/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql new file mode 100644 index 0000000000..9f2b5ebc5a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We incorrectly populated these, so we delete them and let the +-- MultiWriterIdGenerator repopulate it. +DELETE FROM stream_positions WHERE stream_name = 'receipts' OR stream_name = 'account_data'; diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index e34fce6281..871af64b11 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -63,7 +63,7 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.executemany(sql, args) + txn.execute_batch(sql, args) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( @@ -75,7 +75,7 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.executemany(sql, args) + txn.execute_batch(sql, args) else: # This should be unreachable. raise Exception("Unrecognized database engine") diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 9f120d3cb6..50067eabfc 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore): ) return {row["tag"]: db_to_json(row["content"]) for row in rows} - -class TagsStore(TagsWorkerStore): async def add_tag_to_room( self, user_id: str, room_id: str, tag: str, content: JsonDict ) -> int: @@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ + assert self._can_write_to_account_data + content_json = json_encoder.encode(content) def add_tag_txn(txn, next_id): @@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ + assert self._can_write_to_account_data def remove_tag_txn(txn, next_id): sql = ( @@ -250,21 +251,12 @@ class TagsStore(TagsWorkerStore): room_id: The ID of the room. next_id: The the revision to advance to. """ + assert self._can_write_to_account_data txn.call_after( self._account_data_stream_cache.entity_has_changed, user_id, next_id ) - # Note: This is only here for backwards compat to allow admins to - # roll back to a previous Synapse version. Next time we update the - # database version we can remove this table. - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" - ) - txn.execute(update_max_id_sql, (next_id, next_id)) - update_sql = ( "UPDATE room_tags_revisions" " SET stream_id = ?" @@ -288,3 +280,7 @@ class TagsStore(TagsWorkerStore): # which stream_id ends up in the table, as long as it is higher # than the id that the client has. pass + + +class TagsStore(TagsWorkerStore): + pass diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 59207cadd4..cea595ff19 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -464,19 +464,17 @@ class TransactionStore(TransactionWorkerStore): txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str] ) -> List[str]: q = """ - SELECT destination FROM destinations - WHERE destination IN ( - SELECT destination FROM destination_rooms - WHERE destination_rooms.stream_ordering > - destinations.last_successful_stream_ordering - ) - AND destination > ? - AND ( - retry_last_ts IS NULL OR - retry_last_ts + retry_interval < ? - ) - ORDER BY destination - LIMIT 25 + SELECT DISTINCT destination FROM destinations + INNER JOIN destination_rooms USING (destination) + WHERE + stream_ordering > last_successful_stream_ordering + AND destination > ? + AND ( + retry_last_ts IS NULL OR + retry_last_ts + retry_interval < ? + ) + ORDER BY destination + LIMIT 25 """ txn.execute( q, diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 0e31cc811a..89cdc84a9c 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) logger.info("[purge] removing redundant state groups") - txn.executemany( + txn.execute_batch( "DELETE FROM state_groups_state WHERE state_group = ?", ((sg,) for sg in state_groups_to_delete), ) - txn.executemany( + txn.execute_batch( "DELETE FROM state_groups WHERE id = ?", ((sg,) for sg in state_groups_to_delete), ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 01efb2cabb..566ea19bae 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -35,9 +35,6 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -# XXX: If you're about to bump this to 59 (or higher) please create an update -# that drops the unused `cache_invalidation_stream` table, as per #7436! -# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656! SCHEMA_VERSION = 59 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 133c0e7a28..71ef5a72dc 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -15,12 +15,11 @@ import heapq import logging import threading -from collections import deque +from collections import OrderedDict from contextlib import contextmanager -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Tuple, Union import attr -from typing_extensions import Deque from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction @@ -101,7 +100,13 @@ class StreamIdGenerator: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) ) - self._unfinished_ids = deque() # type: Deque[int] + + # We use this as an ordered set, as we want to efficiently append items, + # remove items and get the first item. Since we insert IDs in order, the + # insertion ordering will ensure its in the correct ordering. + # + # The key and values are the same, but we never look at the values. + self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int] def get_next(self): """ @@ -113,7 +118,7 @@ class StreamIdGenerator: self._current += self._step next_id = self._current - self._unfinished_ids.append(next_id) + self._unfinished_ids[next_id] = next_id @contextmanager def manager(): @@ -121,7 +126,7 @@ class StreamIdGenerator: yield next_id finally: with self._lock: - self._unfinished_ids.remove(next_id) + self._unfinished_ids.pop(next_id) return _AsyncCtxManagerWrapper(manager()) @@ -140,7 +145,7 @@ class StreamIdGenerator: self._current += n * self._step for next_id in next_ids: - self._unfinished_ids.append(next_id) + self._unfinished_ids[next_id] = next_id @contextmanager def manager(): @@ -149,7 +154,7 @@ class StreamIdGenerator: finally: with self._lock: for next_id in next_ids: - self._unfinished_ids.remove(next_id) + self._unfinished_ids.pop(next_id) return _AsyncCtxManagerWrapper(manager()) @@ -162,7 +167,7 @@ class StreamIdGenerator: """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - self._step + return next(iter(self._unfinished_ids)) - self._step return self._current @@ -186,11 +191,12 @@ class MultiWriterIdGenerator: Args: db_conn db - stream_name: A name for the stream. + stream_name: A name for the stream, for use in the `stream_positions` + table. (Does not need to be the same as the replication stream name) instance_name: The name of this instance. - table: Database table associated with stream. - instance_column: Column that stores the row's writer's instance name - id_column: Column that stores the stream ID. + tables: List of tables associated with the stream. Tuple of table + name, column name that stores the writer's instance name, and + column name that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. writers: A list of known writers to use to populate current positions @@ -206,9 +212,7 @@ class MultiWriterIdGenerator: db: DatabasePool, stream_name: str, instance_name: str, - table: str, - instance_column: str, - id_column: str, + tables: List[Tuple[str, str, str]], sequence_name: str, writers: List[str], positive: bool = True, @@ -260,15 +264,20 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) # We check that the table and sequence haven't diverged. - self._sequence_gen.check_consistency( - db_conn, table=table, id_column=id_column, positive=positive - ) + for table, _, id_column in tables: + self._sequence_gen.check_consistency( + db_conn, + table=table, + id_column=id_column, + stream_name=stream_name, + positive=positive, + ) # This goes and fills out the above state from the database. - self._load_current_ids(db_conn, table, instance_column, id_column) + self._load_current_ids(db_conn, tables) def _load_current_ids( - self, db_conn, table: str, instance_column: str, id_column: str + self, db_conn, tables: List[Tuple[str, str, str]], ): cur = db_conn.cursor(txn_name="_load_current_ids") @@ -306,17 +315,22 @@ class MultiWriterIdGenerator: # We add a GREATEST here to ensure that the result is always # positive. (This can be a problem for e.g. backfill streams where # the server has never backfilled). - sql = """ - SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) - FROM %(table)s - """ % { - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - cur.execute(sql) - (stream_id,) = cur.fetchone() - self._persisted_upto_position = stream_id + max_stream_id = 1 + for table, _, id_column in tables: + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + + max_stream_id = max(max_stream_id, stream_id) + + self._persisted_upto_position = max_stream_id else: # If we have a min_stream_id then we pull out everything greater # than it from the DB so that we can prefill @@ -329,21 +343,28 @@ class MultiWriterIdGenerator: # stream positions table before restart (or the stream position # table otherwise got out of date). - sql = """ - SELECT %(instance)s, %(id)s FROM %(table)s - WHERE ? %(cmp)s %(id)s - """ % { - "id": id_column, - "table": table, - "instance": instance_column, - "cmp": "<=" if self._positive else ">=", - } - cur.execute(sql, (min_stream_id * self._return_factor,)) - self._persisted_upto_position = min_stream_id + rows = [] + for table, instance_column, id_column in tables: + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + cur.execute(sql, (min_stream_id * self._return_factor,)) + + rows.extend(cur) + + # Sort so that we handle rows in order for each instance. + rows.sort() + with self._lock: - for (instance, stream_id,) in cur: + for (instance, stream_id,) in rows: stream_id = self._return_factor * stream_id self._add_persisted_position(stream_id) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 4386b6101e..0ec4dc2918 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -15,9 +15,8 @@ import abc import logging import threading -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional -from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import ( BaseDatabaseEngine, IncorrectDatabaseSetup, @@ -25,6 +24,9 @@ from synapse.storage.engines import ( ) from synapse.storage.types import Connection, Cursor +if TYPE_CHECKING: + from synapse.storage.database import LoggingDatabaseConnection + logger = logging.getLogger(__name__) @@ -43,6 +45,21 @@ and run the following SQL: See docs/postgres.md for more information. """ +_INCONSISTENT_STREAM_ERROR = """ +Postgres sequence '%(seq)s' is inconsistent with associated stream position +of '%(stream_name)s' in the 'stream_positions' table. + +This is likely a programming error and should be reported at +https://github.com/matrix-org/synapse. + +A temporary workaround to fix this error is to shut down Synapse (including +any and all workers) and run the following SQL: + + DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s'; + +This will need to be done every time the server is restarted. +""" + class SequenceGenerator(metaclass=abc.ABCMeta): """A class which generates a unique sequence of integers""" @@ -53,19 +70,30 @@ class SequenceGenerator(metaclass=abc.ABCMeta): ... @abc.abstractmethod + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + """Get the next `n` IDs in the sequence""" + ... + + @abc.abstractmethod def check_consistency( self, - db_conn: LoggingDatabaseConnection, + db_conn: "LoggingDatabaseConnection", table: str, id_column: str, + stream_name: Optional[str] = None, positive: bool = True, ): """Should be called during start up to test that the current value of the sequence is greater than or equal to the maximum ID in the table. - This is to handle various cases where the sequence value can get out - of sync with the table, e.g. if Synapse gets rolled back to a previous + This is to handle various cases where the sequence value can get out of + sync with the table, e.g. if Synapse gets rolled back to a previous version and the rolled forwards again. + + If a stream name is given then this will check that any value in the + `stream_positions` table is less than or equal to the current sequence + value. If it isn't then it's likely that streams have been crossed + somewhere (e.g. two ID generators have the same stream name). """ ... @@ -88,11 +116,15 @@ class PostgresSequenceGenerator(SequenceGenerator): def check_consistency( self, - db_conn: LoggingDatabaseConnection, + db_conn: "LoggingDatabaseConnection", table: str, id_column: str, + stream_name: Optional[str] = None, positive: bool = True, ): + """See SequenceGenerator.check_consistency for docstring. + """ + txn = db_conn.cursor(txn_name="sequence.check_consistency") # First we get the current max ID from the table. @@ -116,6 +148,18 @@ class PostgresSequenceGenerator(SequenceGenerator): "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} ) last_value, is_called = txn.fetchone() + + # If we have an associated stream check the stream_positions table. + max_in_stream_positions = None + if stream_name: + txn.execute( + "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?", + (stream_name,), + ) + row = txn.fetchone() + if row: + max_in_stream_positions = row[0] + txn.close() # If `is_called` is False then `last_value` is actually the value that @@ -136,6 +180,14 @@ class PostgresSequenceGenerator(SequenceGenerator): % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} ) + # If we have values in the stream positions table then they have to be + # less than or equal to `last_value` + if max_in_stream_positions and max_in_stream_positions > last_value: + raise IncorrectDatabaseSetup( + _INCONSISTENT_STREAM_ERROR + % {"seq": self._sequence_name, "stream_name": stream_name} + ) + GetFirstCallbackType = Callable[[Cursor], int] @@ -172,8 +224,24 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + first_id = self._current_max_id + 1 + self._current_max_id += n + return [first_id + i for i in range(n)] + def check_consistency( - self, db_conn: Connection, table: str, id_column: str, positive: bool = True + self, + db_conn: Connection, + table: str, + id_column: str, + stream_name: Optional[str] = None, + positive: bool = True, ): # There is nothing to do for in memory sequences pass diff --git a/synapse/types.py b/synapse/types.py index c7d4e95809..eafe729dfe 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -37,6 +37,7 @@ from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService @@ -257,8 +258,13 @@ class DomainSpecificString( @classmethod def is_valid(cls: Type[DS], s: str) -> bool: + """Parses the input string and attempts to ensure it is valid.""" try: - cls.from_string(s) + obj = cls.from_string(s) + # Apply additional validation to the domain. This is only done + # during is_valid (and not part of from_string) since it is + # possible for invalid data to exist in room-state, etc. + parse_and_validate_server_name(obj.domain) return True except Exception: return False diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 06faeebe7f..6ef2b008a4 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -13,8 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import heapq from itertools import islice -from typing import Iterable, Iterator, Sequence, Tuple, TypeVar +from typing import ( + Dict, + Generator, + Iterable, + Iterator, + Mapping, + Sequence, + Set, + Tuple, + TypeVar, +) + +from synapse.types import Collection T = TypeVar("T") @@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: If the input is empty, no chunks are returned. """ return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) + + +def sorted_topologically( + nodes: Iterable[T], graph: Mapping[T, Collection[T]], +) -> Generator[T, None, None]: + """Given a set of nodes and a graph, yield the nodes in toplogical order. + + For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`. + """ + + # This is implemented by Kahn's algorithm. + + degree_map = {node: 0 for node in nodes} + reverse_graph = {} # type: Dict[T, Set[T]] + + for node, edges in graph.items(): + if node not in degree_map: + continue + + for edge in edges: + if edge in degree_map: + degree_map[node] += 1 + + reverse_graph.setdefault(edge, set()).add(node) + reverse_graph.setdefault(node, set()) + + zero_degree = [node for node, degree in degree_map.items() if degree == 0] + heapq.heapify(zero_degree) + + while zero_degree: + node = heapq.heappop(zero_degree) + yield node + + for edge in reverse_graph.get(node, []): + if edge in degree_map: + degree_map[edge] -= 1 + if degree_map[edge] == 0: + heapq.heappush(zero_degree, edge) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 61d96a6c28..f8038bf861 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -18,6 +18,7 @@ import random import re import string from collections.abc import Iterable +from typing import Optional, Tuple from synapse.api.errors import Codes, SynapseError @@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, +# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically +# says "there is no grammar for media ids" +# +# The server_name part of this is purposely lax: use parse_and_validate_mxc for +# additional validation. +# +MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") + # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. @@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret): ) +def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == "]": + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") + + +def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) + + return host, port + + +def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]: + """Parse the given string as an MXC URI + + Checks that the "server name" part is a valid server name + + Args: + mxc: the (alleged) MXC URI to be checked + Returns: + hostname, port, media id + Raises: + ValueError if the URI cannot be parsed + """ + m = MXC_REGEX.match(mxc) + if not m: + raise ValueError("mxc URI %r did not match expected format" % (mxc,)) + server_name = m.group(1) + media_id = m.group(2) + host, port = parse_and_validate_server_name(server_name) + return host, port, media_id + + def shortstr(iterable: Iterable, maxitems: int = 5) -> str: """If iterable has maxitems or fewer, return the stringification of a list containing those items. @@ -75,3 +167,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str: if len(items) <= maxitems: return str(items) return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" + + +def strtobool(val: str) -> bool: + """Convert a string representation of truth to True or False + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + + This is lifted from distutils.util.strtobool, with the exception that it actually + returns a bool, rather than an int. + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError("invalid truth value %r" % (val,)) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index bd7a1b6891..c37bb6440e 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -118,4 +118,4 @@ class CasHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "get_user_agent"]) + return Mock(spec=["getClientIP", "getHeader"]) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index f5df657814..b3dfa40d25 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,20 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -import re -from typing import Dict -from urllib.parse import parse_qs, urlencode, urlparse +from typing import Optional +from urllib.parse import parse_qs, urlparse from mock import ANY, Mock, patch import pymacaroons -from twisted.web.resource import Resource - -from synapse.api.errors import RedirectException from synapse.handlers.sso import MappingException -from synapse.rest.client.v1 import login -from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.server import HomeServer from synapse.types import UserID @@ -151,6 +145,7 @@ class OidcHandlerTestCase(HomeserverTestCase): hs = self.setup_test_homeserver(proxied_http_client=self.http_client) self.handler = hs.get_oidc_handler() + self.provider = self.handler._providers["oidc"] sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) @@ -162,9 +157,10 @@ class OidcHandlerTestCase(HomeserverTestCase): return hs def metadata_edit(self, values): - return patch.dict(self.handler._provider_metadata, values) + return patch.dict(self.provider._provider_metadata, values) def assertRenderedError(self, error, error_description=None): + self.render_error.assert_called_once() args = self.render_error.call_args[0] self.assertEqual(args[1], error) if error_description is not None: @@ -175,15 +171,15 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" - self.assertEqual(self.handler._callback_url, CALLBACK_URL) - self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) - self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) + self.assertEqual(self.provider._callback_url, CALLBACK_URL) + self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) + self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {"discover": True}}) def test_discovery(self): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid - metadata = self.get_success(self.handler.load_metadata()) + metadata = self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_called_once_with(WELL_KNOWN) self.assertEqual(metadata.issuer, ISSUER) @@ -195,47 +191,47 @@ class OidcHandlerTestCase(HomeserverTestCase): # subsequent calls should be cached self.http_client.reset_mock() - self.get_success(self.handler.load_metadata()) + self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) def test_no_discovery(self): """When discovery is disabled, it should not try to load from discovery document.""" - self.get_success(self.handler.load_metadata()) + self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) def test_load_jwks(self): """JWKS loading is done once (then cached) if used.""" - jwks = self.get_success(self.handler.load_jwks()) + jwks = self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) self.assertEqual(jwks, {"keys": []}) # subsequent calls should be cachedโฆ self.http_client.reset_mock() - self.get_success(self.handler.load_jwks()) + self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_not_called() # โฆunless forced self.http_client.reset_mock() - self.get_success(self.handler.load_jwks(force=True)) + self.get_success(self.provider.load_jwks(force=True)) self.http_client.get_json.assert_called_once_with(JWKS_URI) # Throw if the JWKS uri is missing with self.metadata_edit({"jwks_uri": None}): - self.get_failure(self.handler.load_jwks(force=True), RuntimeError) + self.get_failure(self.provider.load_jwks(force=True), RuntimeError) # Return empty key set if JWKS are not used - self.handler._scopes = [] # not asking the openid scope + self.provider._scopes = [] # not asking the openid scope self.http_client.get_json.reset_mock() - jwks = self.get_success(self.handler.load_jwks(force=True)) + jwks = self.get_success(self.provider.load_jwks(force=True)) self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) @override_config({"oidc_config": COMMON_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" - h = self.handler + h = self.provider # Default test config does not throw h._validate_metadata() @@ -314,13 +310,13 @@ class OidcHandlerTestCase(HomeserverTestCase): """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw - self.handler._validate_metadata() + self.provider._validate_metadata() def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["addCookie"]) url = self.get_success( - self.handler.handle_redirect_request(req, b"http://client/redirect") + self.provider.handle_redirect_request(req, b"http://client/redirect") ) url = urlparse(url) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) @@ -349,9 +345,13 @@ class OidcHandlerTestCase(HomeserverTestCase): cookie = args[1] macaroon = pymacaroons.Macaroon.deserialize(cookie) - state = self.handler._get_value_from_macaroon(macaroon, "state") - nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") - redirect = self.handler._get_value_from_macaroon( + state = self.handler._token_generator._get_value_from_macaroon( + macaroon, "state" + ) + nonce = self.handler._token_generator._get_value_from_macaroon( + macaroon, "nonce" + ) + redirect = self.handler._token_generator._get_value_from_macaroon( macaroon, "client_redirect_url" ) @@ -384,7 +384,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # ensure that we are correctly testing the fallback when "get_extra_attributes" # is not implemented. - mapping_provider = self.handler._user_mapping_provider + mapping_provider = self.provider._user_mapping_provider with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes @@ -399,9 +399,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.handler._exchange_code = simple_async_mock(return_value=token) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -411,12 +411,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url = "http://client/redirect" user_agent = "Browser" ip_address = "10.0.0.1" - session = self.handler._generate_oidc_session_token( - state=state, - nonce=nonce, - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, - ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) request = _build_callback_request( code, state, session, user_agent=user_agent, ip_address=ip_address ) @@ -426,14 +421,14 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.assert_called_once_with( expected_user_id, request, client_redirect_url, None, ) - self.handler._exchange_code.assert_called_once_with(code) - self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._fetch_userinfo.assert_not_called() + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.provider._fetch_userinfo.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors with patch.object( - self.handler, + self.provider, "_remote_id_from_userinfo", new=Mock(side_effect=MappingException()), ): @@ -441,36 +436,36 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.handler._parse_id_token = simple_async_mock(raises=Exception()) + self.provider._parse_id_token = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") auth_handler.complete_sso_login.reset_mock() - self.handler._exchange_code.reset_mock() - self.handler._parse_id_token.reset_mock() - self.handler._fetch_userinfo.reset_mock() + self.provider._exchange_code.reset_mock() + self.provider._parse_id_token.reset_mock() + self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.handler._scopes = [] # do not ask the "openid" scope + self.provider._scopes = [] # do not ask the "openid" scope self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( expected_user_id, request, client_redirect_url, None, ) - self.handler._exchange_code.assert_called_once_with(code) - self.handler._parse_id_token.assert_not_called() - self.handler._fetch_userinfo.assert_called_once_with(token) + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_not_called() + self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() # Handle userinfo fetching error - self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) + self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure from synapse.handlers.oidc_handler import OidcError - self.handler._exchange_code = simple_async_mock( + self.provider._exchange_code = simple_async_mock( raises=OidcError("invalid_request") ) self.get_success(self.handler.handle_oidc_callback(request)) @@ -500,11 +495,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("invalid_session") # Mismatching session - session = self.handler._generate_oidc_session_token( - state="state", - nonce="nonce", - client_redirect_url="http://client/redirect", - ui_auth_session_id=None, + session = self._generate_oidc_session_token( + state="state", nonce="nonce", client_redirect_url="http://client/redirect", ) request.args = {} request.args[b"state"] = [b"mismatching state"] @@ -528,7 +520,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) ) code = "code" - ret = self.get_success(self.handler._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code)) kwargs = self.http_client.request.call_args[1] self.assertEqual(ret, token) @@ -552,7 +544,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) from synapse.handlers.oidc_handler import OidcError - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "foo") self.assertEqual(exc.value.error_description, "bar") @@ -562,7 +554,7 @@ class OidcHandlerTestCase(HomeserverTestCase): code=500, phrase=b"Internal Server Error", body=b"Not JSON", ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body @@ -574,14 +566,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field @@ -590,7 +582,7 @@ class OidcHandlerTestCase(HomeserverTestCase): code=200, phrase=b"OK", body=b'{"error": "some_error"}', ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @override_config( @@ -616,18 +608,15 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "foo", "phone": "1234567", } - self.handler._exchange_code = simple_async_mock(return_value=token) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() state = "state" client_redirect_url = "http://client/redirect" - session = self.handler._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, + session = self._generate_oidc_session_token( + state=state, nonce="nonce", client_redirect_url=client_redirect_url, ) request = _build_callback_request("code", state, session) @@ -841,116 +830,25 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.assertRenderedError("mapping_error", "localpart is invalid: ") + def _generate_oidc_session_token( + self, + state: str, + nonce: str, + client_redirect_url: str, + ui_auth_session_id: Optional[str] = None, + ) -> str: + from synapse.handlers.oidc_handler import OidcSessionData -class UsernamePickerTestCase(HomeserverTestCase): - if not HAS_OIDC: - skip = "requires OIDC" - - servlets = [login.register_servlets] - - def default_config(self): - config = super().default_config() - config["public_baseurl"] = BASE_URL - oidc_config = { - "enabled": True, - "client_id": CLIENT_ID, - "client_secret": CLIENT_SECRET, - "issuer": ISSUER, - "scopes": SCOPES, - "user_mapping_provider": { - "config": {"display_name_template": "{{ user.displayname }}"} - }, - } - - # Update this config with what's in the default config so that - # override_config works as expected. - oidc_config.update(config.get("oidc_config", {})) - config["oidc_config"] = oidc_config - - # whitelist this client URI so we redirect straight to it rather than - # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} - return config - - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) - return d - - def test_username_picker(self): - """Test the happy path of a username picker flow.""" - client_redirect_url = "https://whitelisted.client" - - # first of all, mock up an OIDC callback to the OidcHandler, which should - # raise a RedirectException - userinfo = {"sub": "tester", "displayname": "Jonny"} - f = self.get_failure( - _make_callback_with_userinfo( - self.hs, userinfo, client_redirect_url=client_redirect_url + return self.handler._token_generator.generate_oidc_session_token( + state=state, + session_data=OidcSessionData( + idp_id="oidc", + nonce=nonce, + client_redirect_url=client_redirect_url, + ui_auth_session_id=ui_auth_session_id, ), - RedirectException, - ) - - # check the Location and cookies returned by the RedirectException - self.assertEqual(f.value.location, b"/_synapse/client/pick_username") - cookieheader = f.value.cookies[0] - regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);") - m = regex.search(cookieheader) - if not m: - self.fail("cookie header %s does not match %s" % (cookieheader, regex)) - - # introspect the sso handler a bit to check that the username mapping session - # looks ok. - session_id = m.group(1).decode("ascii") - username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions - self.assertIn( - session_id, username_mapping_sessions, "session id not found in map" - ) - session = username_mapping_sessions[session_id] - self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, client_redirect_url) - - # the expiry time should be about 15 minutes away - expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) - self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) - - # Now, submit a username to the username picker, which should serve a redirect - # back to the client - submit_path = f.value.location + b"/submit" - content = urlencode({b"username": b"bobby"}).encode("utf8") - chan = self.make_request( - "POST", - path=submit_path, - content=content, - content_is_form=True, - custom_headers=[ - ("Cookie", cookieheader), - # old versions of twisted don't do form-parsing without a valid - # content-length header. - ("Content-Length", str(len(content))), - ], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - # ensure that the returned location starts with the requested redirect URL - self.assertEqual( - location_headers[0][: len(client_redirect_url)], client_redirect_url ) - # fish the login token out of the returned redirect uri - parts = urlparse(location_headers[0]) - query = parse_qs(parts.query) - login_token = query["loginToken"][0] - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. - chan = self.make_request( - "POST", "/login", content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@bobby:test") - async def _make_callback_with_userinfo( hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" @@ -965,17 +863,20 @@ async def _make_callback_with_userinfo( userinfo: the OIDC userinfo dict client_redirect_url: the URL to redirect to on success. """ + from synapse.handlers.oidc_handler import OidcSessionData + handler = hs.get_oidc_handler() - handler._exchange_code = simple_async_mock(return_value={}) - handler._parse_id_token = simple_async_mock(return_value=userinfo) - handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + provider = handler._providers["oidc"] + provider._exchange_code = simple_async_mock(return_value={}) + provider._parse_id_token = simple_async_mock(return_value=userinfo) + provider._fetch_userinfo = simple_async_mock(return_value=userinfo) state = "state" - session = handler._generate_oidc_session_token( + session = handler._token_generator.generate_oidc_session_token( state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, + session_data=OidcSessionData( + idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url, + ), ) request = _build_callback_request("code", state, session) @@ -1011,7 +912,7 @@ def _build_callback_request( "addCookie", "requestHeaders", "getClientIP", - "get_user_agent", + "getHeader", ] ) @@ -1020,5 +921,4 @@ def _build_callback_request( request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] request.getClientIP.return_value = ip_address - request.get_user_agent.return_value = user_agent return request diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 919547556b..022943a10a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -105,6 +105,21 @@ class ProfileTestCase(unittest.TestCase): "Frank", ) + # Set displayname to an empty string + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "" + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_set_my_name_if_disabled(self): self.hs.config.enable_set_displayname = False @@ -223,6 +238,21 @@ class ProfileTestCase(unittest.TestCase): "http://my.server/me.png", ) + # Set avatar to an empty string + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, synapse.types.create_requester(self.frank), "", + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), + ) + @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): self.hs.config.enable_set_avatar_url = False diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 548038214b..261c7083d1 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -262,4 +262,4 @@ class SamlHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "get_user_agent"]) + return Mock(spec=["getClientIP", "getHeader"]) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 4e51839d0f..686012dd25 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -1095,7 +1095,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # Expire both caches and repeat the request self.reactor.pump((10000.0,)) - # Repated the request, this time it should fail if the lookup fails. + # Repeat the request, this time it should fail if the lookup fails. fetch_d = defer.ensureDeferred( self.well_known_resolver.get_well_known(b"testserv") ) @@ -1130,7 +1130,7 @@ class MatrixFederationAgentTests(unittest.TestCase): content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }', ) - # The result is sucessful, but disabled delegation. + # The result is successful, but disabled delegation. r = self.successResultOf(fetch_d) self.assertIsNone(r.delegated_server) diff --git a/tests/http/test_client.py b/tests/http/test_client.py new file mode 100644 index 0000000000..f17c122e93 --- /dev/null +++ b/tests/http/test_client.py @@ -0,0 +1,101 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import BytesIO + +from mock import Mock + +from twisted.python.failure import Failure +from twisted.web.client import ResponseDone + +from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size + +from tests.unittest import TestCase + + +class ReadBodyWithMaxSizeTests(TestCase): + def setUp(self): + """Start reading the body, returns the response, result and proto""" + self.response = Mock() + self.result = BytesIO() + self.deferred = read_body_with_max_size(self.response, self.result, 6) + + # Fish the protocol out of the response. + self.protocol = self.response.deliverBody.call_args[0][0] + self.protocol.transport = Mock() + + def _cleanup_error(self): + """Ensure that the error in the Deferred is handled gracefully.""" + called = [False] + + def errback(f): + called[0] = True + + self.deferred.addErrback(errback) + self.assertTrue(called[0]) + + def test_no_error(self): + """A response that is NOT too large.""" + + # Start sending data. + self.protocol.dataReceived(b"12345") + # Close the connection. + self.protocol.connectionLost(Failure(ResponseDone())) + + self.assertEqual(self.result.getvalue(), b"12345") + self.assertEqual(self.deferred.result, 5) + + def test_too_large(self): + """A response which is too large raises an exception.""" + + # Start sending data. + self.protocol.dataReceived(b"1234567890") + # Close the connection. + self.protocol.connectionLost(Failure(ResponseDone())) + + self.assertEqual(self.result.getvalue(), b"1234567890") + self.assertIsInstance(self.deferred.result, Failure) + self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) + self._cleanup_error() + + def test_multiple_packets(self): + """Data should be accummulated through mutliple packets.""" + + # Start sending data. + self.protocol.dataReceived(b"12") + self.protocol.dataReceived(b"34") + # Close the connection. + self.protocol.connectionLost(Failure(ResponseDone())) + + self.assertEqual(self.result.getvalue(), b"1234") + self.assertEqual(self.deferred.result, 4) + + def test_additional_data(self): + """A connection can receive data after being closed.""" + + # Start sending data. + self.protocol.dataReceived(b"1234567890") + self.assertIsInstance(self.deferred.result, Failure) + self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) + self.protocol.transport.loseConnection.assert_called_once() + + # More data might have come in. + self.protocol.dataReceived(b"1234567890") + # Close the connection. + self.protocol.connectionLost(Failure(ResponseDone())) + + self.assertEqual(self.result.getvalue(), b"1234567890") + self.assertIsInstance(self.deferred.result, Failure) + self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) + self._cleanup_error() diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index b2e9533b07..d06ea518ce 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name +from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name from tests import unittest diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 212484a7fe..9c52c8fdca 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -560,4 +560,4 @@ class FederationClientTests(HomeserverTestCase): self.pump() f = self.failureResultOf(test_d) - self.assertIsInstance(f.value, ValueError) + self.assertIsInstance(f.value, RequestSendFailed) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index 22abf76515..9a56e1c14a 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -15,12 +15,14 @@ import logging import treq +from netaddr import IPSet from twisted.internet import interfaces # noqa: F401 from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.web.http import HTTPChannel +from synapse.http.client import BlacklistingReactorWrapper from synapse.http.proxyagent import ProxyAgent from tests.http import TestServerTLSConnectionFactory, get_test_https_policy @@ -292,6 +294,134 @@ class MatrixFederationAgentTests(TestCase): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") + def test_http_request_via_proxy_with_blacklist(self): + # The blacklist includes the configured proxy IP. + agent = ProxyAgent( + BlacklistingReactorWrapper( + self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"]) + ), + self.reactor, + http_proxy=b"proxy.com:8888", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 8888) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"http://test.com") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request_via_proxy_with_blacklist(self): + # The blacklist includes the configured proxy IP. + agent = ProxyAgent( + BlacklistingReactorWrapper( + self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"]) + ), + self.reactor, + contextFactory=get_test_https_policy(), + https_proxy=b"proxy.com", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 1080) + + # make a test HTTP server, and wire up the client + proxy_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # fish the transports back out so that we can do the old switcheroo + s2c_transport = proxy_server.transport + client_protocol = s2c_transport.other + c2s_transport = client_protocol.transport + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"test.com:443") + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + # this just stops the http Request trying to do a chunked response + # request.setHeader(b"Content-Length", b"0") + request.finish() + + # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel + ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) + ssl_protocol = ssl_factory.buildProtocol(None) + http_server = ssl_protocol.wrappedProtocol + + ssl_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, ssl_protocol) + ) + c2s_transport.other = ssl_protocol + + self.reactor.advance(0) + + server_name = ssl_protocol._tlsConnection.get_servername() + expected_sni = b"test.com" + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + def _wrap_server_factory_for_tls(factory, sanlist=None): """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 0504cd187e..9d22c04073 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -58,8 +58,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -155,9 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.hs = hs - # Allow for uploading and downloading to/from the media repo self.media_repo = hs.get_media_repository_resource() self.download_resource = self.media_repo.children[b"download"] @@ -431,7 +426,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Mark the second item as safe from quarantine. _, media_id_2 = server_and_media_id_2.split("/") - self.get_success(self.store.mark_local_media_as_safe(media_id_2)) + # Quarantine the media + url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) + channel = self.make_request("POST", url, access_token=admin_user_tok) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index aa389df12f..d0090faa4f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -32,8 +32,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -371,8 +369,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index c2b998cdae..51a7731693 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -35,7 +35,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.handler = hs.get_device_handler() self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -181,7 +180,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.handler = hs.get_device_handler() self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index fa620f97f3..a0f32c5512 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -605,8 +605,6 @@ class RoomTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - # Create user self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 73f8a8ec99..f48be3d65a 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -31,7 +31,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9b2e4765f6..e48f8c1d7b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -25,8 +25,10 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError +from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync +from synapse.types import JsonDict from tests import unittest from tests.test_utils import make_awaitable @@ -467,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.user1 = self.register_user( - "user1", "pass1", admin=False, displayname="Name 1" - ) - self.user2 = self.register_user( - "user2", "pass2", admin=False, displayname="Name 2" - ) - def test_no_auth(self): """ Try to list users without authentication. @@ -487,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ + self._create_users(1) other_user_token = self.login("user1", "pass1") channel = self.make_request("GET", self.url, access_token=other_user_token) @@ -498,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ List all users, including deactivated users. """ + self._create_users(2) + channel = self.make_request( "GET", self.url + "?deactivated=true", @@ -510,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(3, channel.json_body["total"]) # Check that all fields are available - for u in channel.json_body["users"]: - self.assertIn("name", u) - self.assertIn("is_guest", u) - self.assertIn("admin", u) - self.assertIn("user_type", u) - self.assertIn("deactivated", u) - self.assertIn("displayname", u) - self.assertIn("avatar_url", u) + self._check_fields(channel.json_body["users"]) def test_search_term(self): """Test that searching for a users works correctly""" @@ -548,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): # Check that users were returned self.assertTrue("users" in channel.json_body) + self._check_fields(channel.json_body["users"]) users = channel.json_body["users"] # Check that the expected number of users were returned @@ -560,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase): u = users[0] self.assertEqual(expected_user_id, u["name"]) + self._create_users(2) + + user1 = "@user1:test" + user2 = "@user2:test" + # Perform search tests - _search_test(self.user1, "er1") - _search_test(self.user1, "me 1") + _search_test(user1, "er1") + _search_test(user1, "me 1") - _search_test(self.user2, "er2") - _search_test(self.user2, "me 2") + _search_test(user2, "er2") + _search_test(user2, "me 2") - _search_test(self.user1, "er1", "user_id") - _search_test(self.user2, "er2", "user_id") + _search_test(user1, "er1", "user_id") + _search_test(user2, "er2", "user_id") # Test case insensitive - _search_test(self.user1, "ER1") - _search_test(self.user1, "NAME 1") + _search_test(user1, "ER1") + _search_test(user1, "NAME 1") - _search_test(self.user2, "ER2") - _search_test(self.user2, "NAME 2") + _search_test(user2, "ER2") + _search_test(user2, "NAME 2") - _search_test(self.user1, "ER1", "user_id") - _search_test(self.user2, "ER2", "user_id") + _search_test(user1, "ER1", "user_id") + _search_test(user2, "ER2", "user_id") _search_test(None, "foo") _search_test(None, "bar") @@ -586,6 +583,373 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + + # negative limit + channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid guests + channel = self.make_request( + "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid deactivated + channel = self.make_request( + "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + def test_limit(self): + """ + Testing list of users with limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 5) + self.assertEqual(channel.json_body["next_token"], "5") + self._check_fields(channel.json_body["users"]) + + def test_from(self): + """ + Testing list of users with a defined starting point (from) + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["users"]) + + def test_limit_and_from(self): + """ + Testing list of users with a defined starting point and limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(channel.json_body["next_token"], "15") + self.assertEqual(len(channel.json_body["users"]), 10) + self._check_fields(channel.json_body["users"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 19) + self.assertEqual(channel.json_body["next_token"], "19") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + Args: + content: List that is checked for content + """ + for u in content: + self.assertIn("name", u) + self.assertIn("is_guest", u) + self.assertIn("admin", u) + self.assertIn("user_type", u) + self.assertIn("deactivated", u) + self.assertIn("displayname", u) + self.assertIn("avatar_url", u) + + def _create_users(self, number_users: int): + """ + Create a number of users + Args: + number_users: Number of users to be created + """ + for i in range(1, number_users + 1): + self.register_user( + "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i, + ) + + +class DeactivateAccountTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass", displayname="User1") + self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) + self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote( + self.other_user + ) + + # set attributes for user + self.get_success( + self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + ) + self.get_success( + self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) + ) + + def test_no_auth(self): + """ + Try to deactivate users without authentication. + """ + channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + url = "/_synapse/admin/v1/deactivate/@bob:test" + + channel = self.make_request("POST", url, access_token=self.other_user_token) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + channel = self.make_request( + "POST", url, access_token=self.other_user_token, content=b"{}", + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + def test_user_does_not_exist(self): + """ + Tests that deactivation for a user that does not exist returns a 404 + """ + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/deactivate/@unknown_person:test", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_erase_is_not_bool(self): + """ + If parameter `erase` is not boolean, return an error + """ + body = json.dumps({"erase": "False"}) + + channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that deactivation for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" + + channel = self.make_request("POST", url, access_token=self.admin_user_tok) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only deactivate local users", channel.json_body["error"]) + + def test_deactivate_user_erase_true(self): + """ + Test deactivating an user and set `erase` to `true` + """ + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + # Deactivate user + body = json.dumps({"erase": True}) + + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + self._is_erased("@user:test", True) + + def test_deactivate_user_erase_false(self): + """ + Test deactivating an user and set `erase` to `false` + """ + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + # Deactivate user + body = json.dumps({"erase": False}) + + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + self._is_erased("@user:test", False) + + def _is_erased(self, user_id: str, expect: bool) -> None: + """Assert that the user is erased or not + """ + d = self.store.is_user_erased(user_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertFalse(self.get_success(d)) + class UserRestTestCase(unittest.HomeserverTestCase): @@ -986,6 +1350,26 @@ class UserRestTestCase(unittest.HomeserverTestCase): Test deactivating another user. """ + # set attributes for user + self.get_success( + self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + ) + self.get_success( + self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) + ) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) + # Deactivate user body = json.dumps({"deactivated": True}) @@ -999,6 +1383,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) # the user is deactivated, the threepid will be deleted # Get user @@ -1009,6 +1396,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) def test_change_name_deactivate_user_user_directory(self): @@ -1204,8 +1594,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1236,24 +1624,26 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" channel = self.make_request("GET", url, access_token=self.admin_user_tok,) - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" channel = self.make_request("GET", url, access_token=self.admin_user_tok,) - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) def test_no_memberships(self): """ @@ -1284,6 +1674,49 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) + def test_get_rooms_with_nonlocal_user(self): + """ + Tests that a normal lookup for rooms is successful with a non-local user + """ + + other_user_tok = self.login("user", "pass") + event_builder_factory = self.hs.get_event_builder_factory() + event_creation_handler = self.hs.get_event_creation_handler() + storage = self.hs.get_storage() + + # Create two rooms, one with a local user only and one with both a local + # and remote user. + self.helper.create_room_as(self.other_user, tok=other_user_tok) + local_and_remote_room_id = self.helper.create_room_as( + self.other_user, tok=other_user_tok + ) + + # Add a remote user to the room. + builder = event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": "m.room.member", + "sender": "@joiner:remote_hs", + "state_key": "@joiner:remote_hs", + "room_id": local_and_remote_room_id, + "content": {"membership": "join"}, + }, + ) + + event, context = self.get_success( + event_creation_handler.create_new_client_event(builder) + ) + + self.get_success(storage.persistence.persist_event(event, context)) + + # Now get rooms + url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) + class PushersRestTestCase(unittest.HomeserverTestCase): @@ -1401,7 +1834,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -1868,8 +2300,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 1d1dc9f8a2..2672ce24c6 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,8 +15,8 @@ import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Union +from urllib.parse import urlencode from mock import Mock @@ -30,12 +30,15 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.types import create_requester from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG -from tests.unittest import override_config, skip_unless +from tests.test_utils.html_parsers import TestHtmlParser +from tests.unittest import HomeserverTestCase, override_config, skip_unless try: import jwt @@ -66,6 +69,12 @@ TEST_SAML_METADATA = """ LOGIN_URL = b"/_matrix/client/r0/login" TEST_URL = b"/_matrix/client/r0/account/whoami" +# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + +TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fรถ%26=o"' + +# the query params in TEST_CLIENT_REDIRECT_URL +EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fรถ&=o"')] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -386,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): }, } + # default OIDC provider config["oidc_config"] = TEST_OIDC_CONFIG + # additional OIDC providers + config["oidc_providers"] = [ + { + "idp_id": "idp1", + "idp_name": "IDP1", + "discover": False, + "issuer": "https://issuer1", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://issuer1/auth", + "token_endpoint": "https://issuer1/token", + "userinfo_endpoint": "https://issuer1/userinfo", + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.sub }}"} + }, + } + ] return config def create_resource_dict(self) -> Dict[str, Resource]: + from synapse.rest.oidc import OIDCResource + d = super().create_resource_dict() d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) + d["/_synapse/oidc"] = OIDCResource(self.hs) return d def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" - client_redirect_url = "https://x?<abc>" - # first hit the redirect url, which should redirect to our idp picker channel = self.make_request( "GET", - "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url, + "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), ) self.assertEqual(channel.code, 302, channel.result) uri = channel.headers.getRawHeaders("Location")[0] @@ -412,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class - class FormPageParser(HTMLParser): - def __init__(self): - super().__init__() - - # the values of the hidden inputs: map from name to value - self.hiddens = {} # type: Dict[str, Optional[str]] - - # the values of the radio buttons - self.radios = [] # type: List[Optional[str]] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "input": - if attr_dict["type"] == "radio" and attr_dict["name"] == "idp": - self.radios.append(attr_dict["value"]) - elif attr_dict["type"] == "hidden": - input_name = attr_dict["name"] - assert input_name - self.hiddens[input_name] = attr_dict["value"] - - def error(_, message): - self.fail(message) - - p = FormPageParser() + p = TestHtmlParser() p.feed(channel.result["body"].decode("utf-8")) p.close() - self.assertCountEqual(p.radios, ["cas", "oidc", "saml"]) + self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"]) - self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url) + self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_cas(self): """If CAS is chosen, should redirect to the CAS server""" - client_redirect_url = "https://x?<abc>" channel = self.make_request( "GET", - "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=cas", shorthand=False, ) self.assertEqual(channel.code, 302, channel.result) @@ -467,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri = cas_uri_params["service"][0] _, service_uri_query = service_uri.split("?", 1) service_uri_params = urllib.parse.parse_qs(service_uri_query) - self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) + self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_saml(self): """If SAML is chosen, should redirect to the SAML server""" - client_redirect_url = "https://x?<abc>" - channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) self.assertEqual(channel.code, 302, channel.result) @@ -489,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # the RelayState is used to carry the client redirect url saml_uri_params = urllib.parse.parse_qs(saml_uri_query) relay_state_param = saml_uri_params["RelayState"][0] - self.assertEqual(relay_state_param, client_redirect_url) + self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_oidc(self): + def test_login_via_oidc(self): """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - client_redirect_url = "https://x?<abc>" + # pick the default OIDC provider channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=oidc", ) self.assertEqual(channel.code, 302, channel.result) @@ -518,8 +522,40 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) self.assertEqual( self._get_value_from_macaroon(macaroon, "client_redirect_url"), - client_redirect_url, + TEST_CLIENT_REDIRECT_URL, + ) + + channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + self.assertTrue( + channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html") + ) + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + + # ... which should contain our redirect link + self.assertEqual(len(p.links), 1) + path, query = p.links[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + login_token = params[2][1] + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@user1:test") def test_multi_sso_redirect_to_unknown(self): """An unknown IdP should cause a 400""" @@ -667,7 +703,9 @@ class CASTestCase(unittest.HomeserverTestCase): # Deactivate the account. self.get_success( - self.deactivate_account_handler.deactivate_account(self.user_id, False) + self.deactivate_account_handler.deactivate_account( + self.user_id, False, create_requester(self.user_id) + ) ) # Request the CAS ticket. @@ -1057,3 +1095,107 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"401", channel.result) + + +@skip_unless(HAS_OIDC, "requires OIDC") +class UsernamePickerTestCase(HomeserverTestCase): + """Tests for the username picker flow of SSO login""" + + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + + config["oidc_config"] = {} + config["oidc_config"].update(TEST_OIDC_CONFIG) + config["oidc_config"]["user_mapping_provider"] = { + "config": {"display_name_template": "{{ user.displayname }}"} + } + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://x"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + from synapse.rest.oidc import OIDCResource + + d = super().create_resource_dict() + d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + d["/_synapse/oidc"] = OIDCResource(self.hs) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + + # do the start of the login flow + channel = self.helper.auth_via_oidc( + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + ) + + # that should redirect to the username picker + self.assertEqual(channel.code, 302, channel.result) + picker_url = channel.headers.getRawHeaders("Location")[0] + self.assertEqual(picker_url, "/_synapse/client/pick_username") + + # ... with a username_mapping_session cookie + cookies = {} # type: Dict[str,str] + channel.extract_cookies(cookies) + self.assertIn("username_mapping_session", cookies) + session_id = cookies["username_mapping_session"] + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, username_mapping_sessions, "session id not found in map", + ) + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # back to the client + submit_path = picker_url + "/submit" + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=submit_path, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", "username_mapping_session=" + session_id), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # fish the login token out of the returned redirect uri + login_token = params[2][1] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6105eac47c..d4e3165436 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,7 +29,7 @@ from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account -from synapse.types import JsonDict, RoomAlias, UserID +from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -1687,7 +1687,9 @@ class ContextTestCase(unittest.HomeserverTestCase): deactivate_account_handler = self.hs.get_deactivate_account_handler() self.get_success( - deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) + deactivate_account_handler.deactivate_account( + self.user_id, True, create_requester(self.user_id) + ) ) # Invite another user in the room. This is needed because messages will be diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 81b7f84360..b1333df82d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -2,7 +2,7 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Optional +from typing import Any, Dict, Mapping, MutableMapping, Optional from mock import patch @@ -32,8 +32,9 @@ from twisted.web.server import Site from synapse.api.constants import Membership from synapse.types import JsonDict -from tests.server import FakeSite, make_request +from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import FakeResponse +from tests.test_utils.html_parsers import TestHtmlParser @attr.s @@ -362,41 +363,128 @@ class RestHelper: the normal places. """ client_redirect_url = "https://x" + channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) - # first hit the redirect url (which will issue a cookie and state) + # expect a confirmation page + assert channel.code == 200, channel.result + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token and device id. channel = make_request( self.hs.get_reactor(), self.site, - "GET", - "/login/sso/redirect?redirectUrl=" + client_redirect_url, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, ) - # that will redirect to the OIDC IdP, but we skip that and go straight + assert channel.code == 200 + return channel.json_body + + def auth_via_oidc( + self, + user_info_dict: JsonDict, + client_redirect_url: Optional[str] = None, + ui_auth_session_id: Optional[str] = None, + ) -> FakeChannel: + """Perform an OIDC authentication flow via a mock OIDC provider. + + This can be used for either login or user-interactive auth. + + Starts by making a request to the relevant synapse redirect endpoint, which is + expected to serve a 302 to the OIDC provider. We then make a request to the + OIDC callback endpoint, intercepting the HTTP requests that will get sent back + to the OIDC provider. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + + Args: + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": "<remote user id>"}'. + client_redirect_url: for a login flow, the client redirect URL to pass to + the login redirect endpoint + ui_auth_session_id: if set, we will perform a UI Auth flow. The session id + of the UI auth. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + Note that the response code may be a 200, 302 or 400 depending on how things + went. + """ + + cookies = {} + + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + + # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" - # param that synapse passes to the IdP via query params, and the cookie that - # synapse passes to the client. - assert channel.code == 302 - oauth_uri = channel.headers.getRawHeaders("Location")[0] - params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query) - redirect_uri = "%s?%s" % ( + # param that synapse passes to the IdP via query params, as well as the cookie + # that synapse passes to the client. + + oauth_uri_path, _ = oauth_uri.split("?", 1) + assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + "unexpected SSO URI " + oauth_uri_path + ) + return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + + def complete_oidc_auth( + self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, + ) -> FakeChannel: + """Mock out an OIDC authentication flow + + Assumes that an OIDC auth has been initiated by one of initiate_sso_login or + initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to + Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get + sent back to the OIDC provider. + + Requires the OIDC callback resource to be mounted at the normal place. + + Args: + oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, + from initiate_sso_login or initiate_sso_ui_auth). + cookies: the cookies set by synapse's redirect endpoint, which will be + sent back to the callback endpoint. + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": "<remote user id>"}'. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + """ + _, oauth_uri_qs = oauth_uri.split("?", 1) + params = urllib.parse.parse_qs(oauth_uri_qs) + callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), ) - cookies = {} - for h in channel.headers.getRawHeaders("Set-Cookie"): - parts = h.split(";") - k, v = parts[0].split("=", maxsplit=1) - cookies[k] = v # before we hit the callback uri, stub out some methods in the http client so # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. expected_requests = [ # first we get a hit to the token endpoint, which we tell to return # a dummy OIDC access token - ("https://issuer.test/token", {"access_token": "TEST"}), + (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), # and then one to the user_info endpoint, which returns our remote user id. - ("https://issuer.test/userinfo", {"sub": remote_user_id}), + (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] async def mock_req(method: str, uri: str, data=None, headers=None): @@ -413,38 +501,85 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", - redirect_uri, + callback_uri, custom_headers=[ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() ], ) + return channel - # expect a confirmation page - assert channel.code == 200 + def initiate_sso_login( + self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the login-via-sso redirect endpoint, and return the target - # fish the matrix login token out of the body of the confirmation page - m = re.search( - 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), - channel.result["body"].decode("utf-8"), - ) - assert m - login_token = m.group(1) + Assumes that exactly one SSO provider has been configured. Requires the login + servlet to be mounted. - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. + Args: + client_redirect_url: the client redirect URL to pass to the login redirect + endpoint + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets redirected to (ie, the SSO server) + """ + params = {} + if client_redirect_url: + params["redirectUrl"] = client_redirect_url + + # hit the redirect url (which will issue a cookie and state) channel = make_request( self.hs.get_reactor(), self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, + "GET", + "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), ) - assert channel.code == 200 - return channel.json_body + + assert channel.code == 302 + channel.extract_cookies(cookies) + return channel.headers.getRawHeaders("Location")[0] + + def initiate_sso_ui_auth( + self, ui_auth_session_id: str, cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the ui-auth-via-sso endpoint, and return the target + + Assumes that exactly one SSO provider has been configured. Requires the + AuthRestServlet to be mounted. + + Args: + ui_auth_session_id: the session id of the UI auth + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets linked to (ie, the SSO server) + """ + sso_redirect_endpoint = ( + "/_matrix/client/r0/auth/m.login.sso/fallback/web?" + + urllib.parse.urlencode({"session": ui_auth_session_id}) + ) + # hit the redirect url (which will issue a cookie and state) + channel = make_request( + self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint + ) + # that should serve a confirmation page + assert channel.code == 200, channel.text_body + channel.extract_cookies(cookies) + + # parse the confirmation page to fish out the link. + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + assert len(p.links) == 1, "not exactly one link in confirmation page" + oauth_uri = p.links[0] + return oauth_uri # an 'oidc_config' suitable for login_via_oidc. TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" TEST_OIDC_CONFIG = { "enabled": True, "discover": False, @@ -453,7 +588,7 @@ TEST_OIDC_CONFIG = { "client_secret": "test-client-secret", "scopes": ["profile"], "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": "https://issuer.test/token", - "userinfo_endpoint": "https://issuer.test/userinfo", + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, } diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index bb91e0c331..a6488a3d29 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector +# Copyright 2020-2021 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Union from twisted.internet.defer import succeed @@ -386,6 +386,44 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_via_sso(self): + """Test a successful UI Auth flow via SSO + + This includes: + * hitting the UIA SSO redirect endpoint + * checking it serves a confirmation page which links to the OIDC provider + * calling back to the synapse oidc callback + * checking that the original operation succeeds + """ + + # log the user in + remote_user_id = UserID.from_string(self.user).localpart + login_resp = self.helper.login_via_oidc(remote_user_id) + self.assertEqual(login_resp["user_id"], self.user) + + # initiate a UI Auth process by attempting to delete the device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # check that SSO is offered + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + + # run the UIA-via-SSO flow + session_id = channel.json_body["session"] + channel = self.helper.auth_via_oidc( + {"sub": remote_user_id}, ui_auth_session_id=session_id + ) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + + # and now the delete request should succeed. + self.delete_device( + self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}}, + ) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self): login_resp = self.helper.login_via_oidc("username") user_tok = login_resp["access_token"] @@ -419,3 +457,32 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertIn({"stages": ["m.login.password"]}, flows) self.assertIn({"stages": ["m.login.sso"]}, flows) self.assertEqual(len(flows), 2) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_fails_for_incorrect_sso_user(self): + """If the user tries to authenticate with the wrong SSO user, they get an error + """ + # log the user in + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + # start a UI Auth flow by attempting to delete a device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + session_id = channel.json_body["session"] + + # do the OIDC auth, but auth as the wrong user + channel = self.helper.auth_via_oidc( + {"sub": "wrong_user"}, ui_auth_session_id=session_id + ) + + # that should return a failure message + self.assertSubstring("We were unable to validate", channel.text_body) + + # ... and the delete op should now fail with a 403 + self.delete_device( + self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + ) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index ae2b32b131..a6c6985173 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): config = self.default_config() config["media_store_path"] = self.media_store_path - config["thumbnail_requirements"] = {} config["max_image_pixels"] = 2000000 provider_config = { @@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) def test_thumbnail_crop(self): + """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( "crop", self.test_image.expected_cropped, self.test_image.expected_found ) def test_thumbnail_scale(self): + """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( "scale", self.test_image.expected_scaled, self.test_image.expected_found ) + def test_invalid_type(self): + """An invalid thumbnail type is never available.""" + self._test_thumbnail("invalid", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} + ) + def test_no_thumbnail_crop(self): + """ + Override the config to generate only scaled thumbnails, but request a cropped one. + """ + self._test_thumbnail("crop", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} + ) + def test_no_thumbnail_scale(self): + """ + Override the config to generate only cropped thumbnails, but request a scaled one. + """ + self._test_thumbnail("scale", None, False) + def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method channel = make_request( diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 14de0921be..c5e44af9f7 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,12 +40,3 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) - - def test_well_known_no_public_baseurl(self): - self.hs.config.public_baseurl = None - - channel = self.make_request( - "GET", "/.well-known/matrix/client", shorthand=False - ) - - self.assertEqual(channel.code, 404) diff --git a/tests/server.py b/tests/server.py index 7d1ad362c4..5a85d5fe7f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -2,7 +2,7 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable, Iterable, Optional, Tuple, Union +from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union import attr from typing_extensions import Deque @@ -51,9 +51,21 @@ class FakeChannel: @property def json_body(self): - if not self.result: - raise Exception("No result yet.") - return json.loads(self.result["body"].decode("utf8")) + return json.loads(self.text_body) + + @property + def text_body(self) -> str: + """The body of the result, utf-8-decoded. + + Raises an exception if the request has not yet completed. + """ + if not self.is_finished: + raise Exception("Request not yet completed") + return self.result["body"].decode("utf8") + + def is_finished(self) -> bool: + """check if the response has been completely received""" + return self.result.get("done", False) @property def code(self): @@ -62,7 +74,7 @@ class FakeChannel: return int(self.result["code"]) @property - def headers(self): + def headers(self) -> Headers: if not self.result: raise Exception("No result yet.") h = Headers() @@ -124,7 +136,7 @@ class FakeChannel: self._reactor.run() x = 0 - while not self.result.get("done"): + while not self.is_finished(): # If there's a producer, tell it to resume producing so we get content if self._producer: self._producer.resumeProducing() @@ -136,6 +148,16 @@ class FakeChannel: self._reactor.advance(0.1) + def extract_cookies(self, cookies: MutableMapping[str, str]) -> None: + """Process the contents of any Set-Cookie headers in the response + + Any cookines found are added to the given dict + """ + for h in self.headers.getRawHeaders("Set-Cookie"): + parts = h.split(";") + k, v = parts[0].split("=", maxsplit=1) + cookies[k] = v + class FakeSite: """ diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py new file mode 100644 index 0000000000..0c46ad595b --- /dev/null +++ b/tests/storage/test_event_chain.py @@ -0,0 +1,741 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Set, Tuple + +from twisted.trial import unittest + +from synapse.api.constants import EventTypes +from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.storage.databases.main.events import _LinkMap +from synapse.types import create_requester + +from tests.unittest import HomeserverTestCase + + +class EventChainStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self._next_stream_ordering = 1 + + def test_simple(self): + """Test that the example in `docs/auth_chain_difference_algorithm.md` + works. + """ + + event_factory = self.hs.get_event_builder_factory() + bob = "@creator:test" + alice = "@alice:test" + room_id = "!room:test" + + # Ensure that we have a rooms entry so that we generate the chain index. + self.get_success( + self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=True, + room_version=RoomVersions.V6, + ) + ) + + create = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Create, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "create"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[]) + ) + + bob_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) + ) + + power = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power"}, + }, + ).build( + prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + ) + ) + + alice_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": bob, + "room_id": room_id, + "content": {"tag": "alice_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], + ) + ) + + power_2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power_2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + bob_join_2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join_2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[ + create.event_id, + alice_join.event_id, + power_2.event_id, + ], + ) + ) + + events = [ + create, + bob_join, + power, + alice_invite, + alice_join, + bob_join_2, + power_2, + alice_join2, + ] + + expected_links = [ + (bob_join, create), + (power, create), + (power, bob_join), + (alice_invite, create), + (alice_invite, power), + (alice_invite, bob_join), + (bob_join_2, power), + (alice_join2, power_2), + ] + + self.persist(events) + chain_map, link_map = self.fetch_chains(events) + + # Check that the expected links and only the expected links have been + # added. + self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) + + # Test that everything can reach the create event, but the create event + # can't reach anything. + for event in events[1:]: + self.assertTrue( + link_map.exists_path_from( + chain_map[event.event_id], chain_map[create.event_id] + ), + ) + + self.assertFalse( + link_map.exists_path_from( + chain_map[create.event_id], chain_map[event.event_id], + ), + ) + + def test_out_of_order_events(self): + """Test that we handle persisting events that we don't have the full + auth chain for yet (which should only happen for out of band memberships). + """ + event_factory = self.hs.get_event_builder_factory() + bob = "@creator:test" + alice = "@alice:test" + room_id = "!room:test" + + # Ensure that we have a rooms entry so that we generate the chain index. + self.get_success( + self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=True, + room_version=RoomVersions.V6, + ) + ) + + # First persist the base room. + create = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Create, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "create"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[]) + ) + + bob_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) + ) + + power = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power"}, + }, + ).build( + prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + ) + ) + + self.persist([create, bob_join, power]) + + # Now persist an invite and a couple of memberships out of order. + alice_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": bob, + "room_id": room_id, + "content": {"tag": "alice_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], + ) + ) + + alice_join2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_join.event_id, power.event_id], + ) + ) + + self.persist([alice_join]) + self.persist([alice_join2]) + self.persist([alice_invite]) + + # The end result should be sane. + events = [create, bob_join, power, alice_invite, alice_join] + + chain_map, link_map = self.fetch_chains(events) + + expected_links = [ + (bob_join, create), + (power, create), + (power, bob_join), + (alice_invite, create), + (alice_invite, power), + (alice_invite, bob_join), + ] + + # Check that the expected links and only the expected links have been + # added. + self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) + + def persist( + self, events: List[EventBase], + ): + """Persist the given events and check that the links generated match + those given. + """ + + persist_events_store = self.hs.get_datastores().persist_events + + for e in events: + e.internal_metadata.stream_ordering = self._next_stream_ordering + self._next_stream_ordering += 1 + + def _persist(txn): + # We need to persist the events to the events and state_events + # tables. + persist_events_store._store_event_txn(txn, [(e, {}) for e in events]) + + # Actually call the function that calculates the auth chain stuff. + persist_events_store._persist_event_auth_chain_txn(txn, events) + + self.get_success( + persist_events_store.db_pool.runInteraction("_persist", _persist,) + ) + + def fetch_chains( + self, events: List[EventBase] + ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: + + # Fetch the map from event ID -> (chain ID, sequence number) + rows = self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chains", + column="event_id", + iterable=[e.event_id for e in events], + retcols=("event_id", "chain_id", "sequence_number"), + keyvalues={}, + ) + ) + + chain_map = { + row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows + } + + # Fetch all the links and pass them to the _LinkMap. + rows = self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chain_links", + column="origin_chain_id", + iterable=[chain_id for chain_id, _ in chain_map.values()], + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), + keyvalues={}, + ) + ) + + link_map = _LinkMap() + for row in rows: + added = link_map.add_link( + (row["origin_chain_id"], row["origin_sequence_number"]), + (row["target_chain_id"], row["target_sequence_number"]), + ) + + # We shouldn't have persisted any redundant links + self.assertTrue(added) + + return chain_map, link_map + + +class LinkMapTestCase(unittest.TestCase): + def test_simple(self): + """Basic tests for the LinkMap. + """ + link_map = _LinkMap() + + link_map.add_link((1, 1), (2, 1), new=False) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)]) + self.assertCountEqual(link_map.get_additions(), []) + self.assertTrue(link_map.exists_path_from((1, 5), (2, 1))) + self.assertFalse(link_map.exists_path_from((1, 5), (2, 2))) + self.assertTrue(link_map.exists_path_from((1, 5), (1, 1))) + self.assertFalse(link_map.exists_path_from((1, 1), (1, 5))) + + # Attempting to add a redundant link is ignored. + self.assertFalse(link_map.add_link((1, 4), (2, 1))) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + + # Adding new non-redundant links works + self.assertTrue(link_map.add_link((1, 3), (2, 3))) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + + self.assertTrue(link_map.add_link((2, 5), (1, 3))) + self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)]) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + + self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) + + +class EventChainBackgroundUpdateTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + self.requester = create_requester(self.user_id) + + def _generate_room(self) -> Tuple[str, List[Set[str]]]: + """Insert a room without a chain cover index. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Mark the room as not having a chain cover index + self.get_success( + self.store.db_pool.simple_update( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": False}, + desc="test", + ) + ) + + # Create a fork in the DAG with different events. + event_handler = self.hs.get_event_creation_handler() + latest_event_ids = self.get_success( + self.store.get_prev_events_for_room(room_id) + ) + event, context = self.get_success( + event_handler.create_event( + self.requester, + { + "type": "some_state_type", + "state_key": "", + "content": {}, + "room_id": room_id, + "sender": self.user_id, + }, + prev_event_ids=latest_event_ids, + ) + ) + self.get_success( + event_handler.handle_new_client_event(self.requester, event, context) + ) + state1 = set(self.get_success(context.get_current_state_ids()).values()) + + event, context = self.get_success( + event_handler.create_event( + self.requester, + { + "type": "some_state_type", + "state_key": "", + "content": {}, + "room_id": room_id, + "sender": self.user_id, + }, + prev_event_ids=latest_event_ids, + ) + ) + self.get_success( + event_handler.handle_new_client_event(self.requester, event, context) + ) + state2 = set(self.get_success(context.get_current_state_ids()).values()) + + # Delete the chain cover info. + + def _delete_tables(txn): + txn.execute("DELETE FROM event_auth_chains") + txn.execute("DELETE FROM event_auth_chain_links") + + self.get_success(self.store.db_pool.runInteraction("test", _delete_tables)) + + return room_id, [state1, state2] + + def test_background_update_single_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + room_id, states = self._generate_room() + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id, + states, + ) + ) + + def test_background_update_multiple_rooms(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + # Create a room + room_id1, states1 = self._generate_room() + room_id2, states2 = self._generate_room() + room_id3, states2 = self._generate_room() + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id1, + states1, + ) + ) + + def test_background_update_single_large_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + room_id, states = self._generate_room() + + # Add a bunch of state so that it takes multiple iterations of the + # background update to process the room. + for i in range(0, 150): + self.helper.send_state( + room_id, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id, + states, + ) + ) + + def test_background_update_multiple_large_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create the rooms + room_id1, _ = self._generate_room() + room_id2, _ = self._generate_room() + + # Add a bunch of state so that it takes multiple iterations of the + # background update to process the room. + for i in range(0, 150): + self.helper.send_state( + room_id1, event_type="m.test", body={"index": i}, tok=self.token + ) + + for i in range(0, 150): + self.helper.send_state( + room_id2, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2))) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 482506d731..9d04a066d8 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,6 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import attr +from parameterized import parameterized + +from synapse.events import _EventInternalMetadata + import tests.unittest import tests.utils @@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) - def test_auth_difference(self): + @parameterized.expand([(True,), (False,)]) + def test_auth_difference(self, use_chain_cover_index: bool): room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): "j": 1, } + # Mark the room as not having a cover index + + def store_room(txn): + self.store.db_pool.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + "has_auth_chain_index": use_chain_cover_index, + }, + ) + + self.get_success(self.store.db_pool.runInteraction("store_room", store_room)) + # We rudely fiddle with the appropriate tables directly, as that's much # easier than constructing events properly. - def insert_event(txn, event_id, stream_ordering): + def insert_event(txn): + stream_ordering = 0 + + for event_id in auth_graph: + stream_ordering += 1 + depth = depth_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.hs.datastores.persist_events._persist_event_auth_chain_txn( + txn, + [ + FakeEvent(event_id, room_id, auth_graph[event_id]) + for event_id in auth_graph + ], + ) + + self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}]) + ) + self.assertSetEqual(difference, set()) + + def test_auth_difference_partial_cover(self): + """Test that we correctly handle rooms where not all events have a chain + cover calculated. This can happen in some obscure edge cases, including + during the background update that calculates the chain cover for old + rooms. + """ + + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ยด | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } - depth = depth_map[event_id] + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + def insert_event(txn): + # First insert the room and mark it as having a chain cover. self.store.db_pool.simple_insert_txn( txn, - table="events", - values={ - "event_id": event_id, + "rooms", + { "room_id": room_id, - "depth": depth, - "topological_ordering": depth, - "type": "m.test", - "processed": True, - "outlier": False, - "stream_ordering": stream_ordering, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + "has_auth_chain_index": True, }, ) - self.store.db_pool.simple_insert_many_txn( + stream_ordering = 0 + + for event_id in auth_graph: + stream_ordering += 1 + depth = depth_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + # Insert all events apart from 'B' + self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, - table="event_auth", - values=[ - {"event_id": event_id, "room_id": room_id, "auth_id": a} - for a in auth_graph[event_id] + [ + FakeEvent(event_id, room_id, auth_graph[event_id]) + for event_id in auth_graph + if event_id != "b" ], ) - next_stream_ordering = 0 - for event_id in auth_graph: - next_stream_ordering += 1 - self.get_success( - self.store.db_pool.runInteraction( - "insert", insert_event, event_id, next_stream_ordering - ) + # Now we insert the event 'B' without a chain cover, by temporarily + # pretending the room doesn't have a chain cover. + + self.store.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": False}, + ) + + self.hs.datastores.persist_events._persist_event_auth_chain_txn( + txn, [FakeEvent("b", room_id, auth_graph["b"])], + ) + + self.store.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": True}, ) + self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + # Now actually test that various combinations give the right result: difference = self.get_success( @@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_auth_chain_difference(room_id, [{"a"}]) ) self.assertSetEqual(difference, set()) + + +@attr.s +class FakeEvent: + event_id = attr.ib() + room_id = attr.ib() + auth_events = attr.ib() + + type = "foo" + state_key = "foo" + + internal_metadata = _EventInternalMetadata({}) + + def auth_event_ids(self): + return self.auth_events + + def is_state(self): + return True diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index cc0612cf65..3e2fd4da01 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.db_pool, stream_name="test_stream", instance_name=instance_name, - table="foobar", - instance_column="instance_name", - id_column="stream_id", + tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", writers=writers, ) @@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.db_pool, stream_name="test_stream", instance_name=instance_name, - table="foobar", - instance_column="instance_name", - id_column="stream_id", + tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", writers=writers, positive=False, @@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) + + +class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db_pool = self.store.db_pool # type: DatabasePool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar1 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + txn.execute( + """ + CREATE TABLE foobar2 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db_pool, + stream_name="test_stream", + instance_name=instance_name, + tables=[ + ("foobar1", "instance_name", "stream_id"), + ("foobar2", "instance_name", "stream_id"), + ], + sequence_name="foobar_seq", + writers=writers, + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def _insert_rows( + self, + table: str, + instance_name: str, + number: int, + update_stream_table: bool = True, + ): + """Insert N rows as the given instance, inserting with stream IDs pulled + from the postgres sequence. + """ + + def _insert(txn): + for _ in range(number): + txn.execute( + "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), + (instance_name,), + ) + if update_stream_table: + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) + + self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) + + def test_load_existing_stream(self): + """Test creating ID gens with multiple tables that have rows from after + the position in `stream_positions` table. + """ + self._insert_rows("foobar1", "first", 3) + self._insert_rows("foobar2", "second", 3) + self._insert_rows("foobar2", "second", 1, update_stream_table=False) + + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) + + # The first ID gen will notice that it can advance its token to 7 as it + # has no in progress writes... + self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) + + # ... but the second ID gen doesn't know that. + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 3fd0a38cf5..ea63bd56b4 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase): ), ) + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_avatar_url(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) @@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase): ) ), ) + + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_avatar_url(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ) + ) diff --git a/tests/test_types.py b/tests/test_types.py index 480bea1bdc..acdeea7a09 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -58,6 +58,10 @@ class RoomAliasTestCase(unittest.HomeserverTestCase): self.assertEquals(room.to_string(), "#channel:my.domain") + def test_validate(self): + id_string = "#test:domain,test" + self.assertFalse(RoomAlias.is_valid(id_string)) + class GroupIDTestCase(unittest.TestCase): def test_parse(self): diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py new file mode 100644 index 0000000000..ad563eb3f0 --- /dev/null +++ b/tests/test_utils/html_parsers.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from html.parser import HTMLParser +from typing import Dict, Iterable, List, Optional, Tuple + + +class TestHtmlParser(HTMLParser): + """A generic HTML page parser which extracts useful things from the HTML""" + + def __init__(self): + super().__init__() + + # a list of links found in the doc + self.links = [] # type: List[str] + + # the values of any hidden <input>s: map from name to value + self.hiddens = {} # type: Dict[str, Optional[str]] + + # the values of any radio buttons: map from name to list of values + self.radios = {} # type: Dict[str, List[Optional[str]]] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "a": + href = attr_dict["href"] + if href: + self.links.append(href) + elif tag == "input": + input_name = attr_dict.get("name") + if attr_dict["type"] == "radio": + assert input_name + self.radios.setdefault(input_name, []).append(attr_dict["value"]) + elif attr_dict["type"] == "hidden": + assert input_name + self.hiddens[input_name] = attr_dict["value"] + + def error(_, message): + raise AssertionError(message) diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 0ab0a91483..522c8061f9 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.iterutils import chunk_seq +from typing import Dict, List + +from synapse.util.iterutils import chunk_seq, sorted_topologically from tests.unittest import TestCase @@ -45,3 +47,48 @@ class ChunkSeqTests(TestCase): self.assertEqual( list(parts), [], ) + + +class SortTopologically(TestCase): + def test_empty(self): + "Test that an empty graph works correctly" + + graph = {} # type: Dict[int, List[int]] + self.assertEqual(list(sorted_topologically([], graph)), []) + + def test_handle_empty_graph(self): + "Test that a graph where a node doesn't have an entry is treated as empty" + + graph = {} # type: Dict[int, List[int]] + + # For disconnected nodes the output is simply sorted. + self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) + + def test_disconnected(self): + "Test that a graph with no edges work" + + graph = {1: [], 2: []} # type: Dict[int, List[int]] + + # For disconnected nodes the output is simply sorted. + self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) + + def test_linear(self): + "Test that a simple `4 -> 3 -> 2 -> 1` graph works" + + graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) + + def test_subset(self): + "Test that only sorting a subset of the graph works" + graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) + + def test_fork(self): + "Test that a forked graph works" + graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]] + + # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should + # always get the same one. + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) diff --git a/tests/utils.py b/tests/utils.py index 977eeaf6ee..09614093bc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -159,7 +159,6 @@ def default_config(name, parse=False): "remote": {"per_second": 10000, "burst_count": 10000}, }, "saml2_enabled": False, - "public_baseurl": None, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, "old_signing_keys": {}, diff --git a/tox.ini b/tox.ini index 297136fcc5..1a3489344f 100644 --- a/tox.ini +++ b/tox.ini @@ -24,7 +24,8 @@ deps = # install the "enum34" dependency of cryptography. pip>=10 -# directories/files we run the linters on +# directories/files we run the linters on. +# if you update this list, make sure to do the same in scripts-dev/lint.sh lint_targets = setup.py synapse @@ -103,6 +104,9 @@ usedevelop=true [testenv:py35-old] skip_install=True deps = + # Ensure a version of setuptools that supports Python 3.5 is installed. + setuptools < 51.0.0 + # Old automat version for Twisted Automat == 0.3.0 |